author     Volpeon <git@volpeon.ink>  2022-09-30 14:13:51 +0200
committer  Volpeon <git@volpeon.ink>  2022-09-30 14:13:51 +0200
commit     9a42def9fcfb9a5c5471d640253ed6c8f45c4973 (patch)
tree       ad186862f5095663966dd1d42455023080aa0c4e /infer.py
parent     Better sample file structure (diff)
Added custom SD pipeline + euler_a scheduler
Diffstat (limited to 'infer.py')
-rw-r--r--  infer.py  111
1 file changed, 80 insertions(+), 31 deletions(-)
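With the new arguments in place, a typical invocation of the updated script might look like this (a hypothetical example: the model path and prompts are placeholders, and --model/--prompt are assumed to be pre-existing arguments, as the uses of args.model and args.prompt in the diff suggest):

    python infer.py \
        --model /path/to/trained-model \
        --prompt "a photo of a dog in a bucket" \
        --scheduler euler_a \
        --steps 120 \
        --guidance_scale 7.5 \
        --clip_guidance_scale 100 \
        --width 512 --height 512 \
        --batch_size 1 \
        --output_dir output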
diff --git a/infer.py b/infer.py
index f2007e9..de3d792 100644
--- a/infer.py
+++ b/infer.py
@@ -1,18 +1,15 @@
 import argparse
 import datetime
+import logging
 from pathlib import Path
 from torch import autocast
-from diffusers import StableDiffusionPipeline
 import torch
 import json
-from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel, PNDMScheduler
-from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
+from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
+from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
 from slugify import slugify
-from pipelines.stable_diffusion.no_check import NoCheck
-
-model_id = "path-to-your-trained-model"
-
-prompt = "A photo of sks dog in a bucket"
+from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion
+from schedulers.scheduling_euler_a import EulerAScheduler
 
 
 def parse_args():
@@ -30,6 +27,21 @@ def parse_args():
         default=None,
     )
     parser.add_argument(
+        "--negative_prompt",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "--width",
+        type=int,
+        default=512,
+    )
+    parser.add_argument(
+        "--height",
+        type=int,
+        default=512,
+    )
+    parser.add_argument(
         "--batch_size",
         type=int,
         default=1,
@@ -42,17 +54,28 @@ def parse_args():
     parser.add_argument(
         "--steps",
         type=int,
-        default=80,
+        default=120,
     )
     parser.add_argument(
-        "--scale",
+        "--scheduler",
+        type=str,
+        choices=["plms", "ddim", "klms", "euler_a"],
+        default="euler_a",
+    )
+    parser.add_argument(
+        "--guidance_scale",
         type=int,
         default=7.5,
     )
     parser.add_argument(
+        "--clip_guidance_scale",
+        type=int,
+        default=100,
+    )
+    parser.add_argument(
         "--seed",
         type=int,
-        default=None,
+        default=torch.random.seed(),
     )
     parser.add_argument(
         "--output_dir",
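Worth noting: the new seed default, torch.random.seed(), is evaluated once when parse_args() constructs the parser, so a run without --seed still gets a single fixed base seed that gen() below offsets per batch. A minimal sketch of that behavior (illustrative only, not part of the commit; the script itself uses a CUDA generator, and 3 stands in for args.batch_num):

    import torch

    base_seed = torch.random.seed()  # reseeds the default RNG and returns the new 64-bit seed
    for i in range(3):
        # one generator per batch, reproducible given the same base seed
        generator = torch.Generator(device="cpu").manual_seed(base_seed + i)
        print(generator.initial_seed())  # base_seed + i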
@@ -81,31 +104,39 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def main():
-    args = parse_args()
-
-    seed = args.seed or torch.random.seed()
-
-    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}")
-    output_dir.mkdir(parents=True, exist_ok=True)
-    save_args(output_dir, args)
-
+def gen(args, output_dir):
     tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16)
     text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16)
+    clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16)
     vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16)
     unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16)
-    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.bfloat16)
+    feature_extractor = CLIPFeatureExtractor.from_pretrained(
+        "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16)
 
-    pipeline = StableDiffusionPipeline(
+    if args.scheduler == "plms":
+        scheduler = PNDMScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+        )
+    elif args.scheduler == "klms":
+        scheduler = LMSDiscreteScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+        )
+    elif args.scheduler == "ddim":
+        scheduler = DDIMScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
+        )
+    else:
+        scheduler = EulerAScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
+        )
+
+    pipeline = CLIPGuidedStableDiffusion(
         text_encoder=text_encoder,
         vae=vae,
         unet=unet,
         tokenizer=tokenizer,
-        scheduler=PNDMScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-        ),
-        safety_checker=NoCheck(),
+        clip_model=clip_model,
+        scheduler=scheduler,
         feature_extractor=feature_extractor
     )
     pipeline.enable_attention_slicing()
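The scheduler selection above could equally be written as a table-driven dispatch. The sketch below is an equivalent reformulation, not part of this commit; the constructor arguments are copied verbatim from the hunk, and EulerAScheduler is this repository's own scheduler module:

    from diffusers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
    from schedulers.scheduling_euler_a import EulerAScheduler

    SCHEDULERS = {
        "plms": lambda: PNDMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True),
        "klms": lambda: LMSDiscreteScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"),
        "ddim": lambda: DDIMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
            clip_sample=False, set_alpha_to_one=False),
        "euler_a": lambda: EulerAScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
            clip_sample=False, set_alpha_to_one=False),
    }

    # choices=[...] in parse_args() guarantees the key exists
    scheduler = SCHEDULERS[args.scheduler]()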
@@ -113,16 +144,34 @@ def main():
 
     with autocast("cuda"):
         for i in range(args.batch_num):
-            generator = torch.Generator(device="cuda").manual_seed(seed + i)
+            generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
             images = pipeline(
-                [args.prompt] * args.batch_size,
+                prompt=[args.prompt] * args.batch_size,
+                height=args.height,
+                width=args.width,
+                negative_prompt=args.negative_prompt,
                 num_inference_steps=args.steps,
-                guidance_scale=args.scale,
+                guidance_scale=args.guidance_scale,
+                clip_guidance_scale=args.clip_guidance_scale,
                 generator=generator,
             ).images
 
             for j, image in enumerate(images):
-                image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg"))
+                image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"))
+
+
+def main():
+    args = parse_args()
+
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}")
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    save_args(output_dir, args)
+
+    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
+
+    gen(args, output_dir)
 
 
 if __name__ == "__main__":
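After this change, each run writes into a timestamped, prompt-slugged directory under --output_dir. A sketch of the resulting layout (names are illustrative; image files follow the f"{args.seed + i}_{j}.jpg" pattern from gen(), and log.txt comes from the logging.basicConfig call in main()):

    output/
        2022-09-30T14-13-51_a-photo-of-a-dog-in-a-bucket/
            log.txt              # debug log from logging.basicConfig
            1234567890_0.jpg     # batch i=0, image j=0 (seed = args.seed + 0)
            1234567891_0.jpg     # batch i=1, image j=0 (seed = args.seed + 1)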