summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-11-27 16:57:29 +0100
committerVolpeon <git@volpeon.ink>2022-11-27 16:57:29 +0100
commitb9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d (patch)
tree2ad3740868696fc071d8850171e6e53ccc3a7bd2 /infer.py
parentUpdate (diff)
downloadtextual-inversion-diff-b9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d.tar.gz
textual-inversion-diff-b9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d.tar.bz2
textual-inversion-diff-b9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d.zip
Update
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py52
1 files changed, 23 insertions, 29 deletions
diff --git a/infer.py b/infer.py
index 2bf9cb3..ab5f247 100644
--- a/infer.py
+++ b/infer.py
@@ -20,7 +20,6 @@ torch.backends.cuda.matmul.allow_tf32 = True
20 20
21default_args = { 21default_args = {
22 "model": None, 22 "model": None,
23 "scheduler": "dpmpp",
24 "precision": "fp32", 23 "precision": "fp32",
25 "ti_embeddings_dir": "embeddings_ti", 24 "ti_embeddings_dir": "embeddings_ti",
26 "output_dir": "output/inference", 25 "output_dir": "output/inference",
@@ -29,6 +28,7 @@ default_args = {
29 28
30 29
31default_cmds = { 30default_cmds = {
31 "scheduler": "dpmpp",
32 "prompt": None, 32 "prompt": None,
33 "negative_prompt": None, 33 "negative_prompt": None,
34 "image": None, 34 "image": None,
@@ -62,11 +62,6 @@ def create_args_parser():
62 type=str, 62 type=str,
63 ) 63 )
64 parser.add_argument( 64 parser.add_argument(
65 "--scheduler",
66 type=str,
67 choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
68 )
69 parser.add_argument(
70 "--precision", 65 "--precision",
71 type=str, 66 type=str,
72 choices=["fp32", "fp16", "bf16"], 67 choices=["fp32", "fp16", "bf16"],
@@ -92,6 +87,11 @@ def create_cmd_parser():
92 description="Simple example of a training script." 87 description="Simple example of a training script."
93 ) 88 )
94 parser.add_argument( 89 parser.add_argument(
90 "--scheduler",
91 type=str,
92 choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
93 )
94 parser.add_argument(
95 "--prompt", 95 "--prompt",
96 type=str, 96 type=str,
97 nargs="+", 97 nargs="+",
@@ -199,37 +199,17 @@ def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
199 print(f"Loaded {placeholder_token}") 199 print(f"Loaded {placeholder_token}")
200 200
201 201
202def create_pipeline(model, scheduler, ti_embeddings_dir, dtype): 202def create_pipeline(model, ti_embeddings_dir, dtype):
203 print("Loading Stable Diffusion pipeline...") 203 print("Loading Stable Diffusion pipeline...")
204 204
205 tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) 205 tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
206 text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype) 206 text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
207 vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype) 207 vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
208 unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) 208 unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
209 scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
209 210
210 load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir) 211 load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir)
211 212
212 if scheduler == "plms":
213 scheduler = PNDMScheduler(
214 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
215 )
216 elif scheduler == "klms":
217 scheduler = LMSDiscreteScheduler(
218 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
219 )
220 elif scheduler == "ddim":
221 scheduler = DDIMScheduler(
222 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
223 )
224 elif scheduler == "dpmpp":
225 scheduler = DPMSolverMultistepScheduler(
226 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
227 )
228 else:
229 scheduler = EulerAncestralDiscreteScheduler(
230 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
231 )
232
233 pipeline = VlpnStableDiffusion( 213 pipeline = VlpnStableDiffusion(
234 text_encoder=text_encoder, 214 text_encoder=text_encoder,
235 vae=vae, 215 vae=vae,
@@ -264,6 +244,17 @@ def generate(output_dir, pipeline, args):
264 else: 244 else:
265 init_image = None 245 init_image = None
266 246
247 if args.scheduler == "plms":
248 pipeline.scheduler = PNDMScheduler.from_config(pipeline.scheduler.config)
249 elif args.scheduler == "klms":
250 pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
251 elif args.scheduler == "ddim":
252 pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
253 elif args.scheduler == "dpmpp":
254 pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
255 elif args.scheduler == "euler_a":
256 pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
257
267 with torch.autocast("cuda"), torch.inference_mode(): 258 with torch.autocast("cuda"), torch.inference_mode():
268 for i in range(args.batch_num): 259 for i in range(args.batch_num):
269 pipeline.set_progress_bar_config( 260 pipeline.set_progress_bar_config(
@@ -331,6 +322,9 @@ class CmdParse(cmd.Cmd):
331 generate(self.output_dir, self.pipeline, args) 322 generate(self.output_dir, self.pipeline, args)
332 except KeyboardInterrupt: 323 except KeyboardInterrupt:
333 print('Generation cancelled.') 324 print('Generation cancelled.')
325 except Exception as e:
326 print(e)
327 return
334 328
335 def do_exit(self, line): 329 def do_exit(self, line):
336 return True 330 return True
@@ -345,7 +339,7 @@ def main():
345 output_dir = Path(args.output_dir) 339 output_dir = Path(args.output_dir)
346 dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] 340 dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
347 341
348 pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, dtype) 342 pipeline = create_pipeline(args.model, args.ti_embeddings_dir, dtype)
349 cmd_parser = create_cmd_parser() 343 cmd_parser = create_cmd_parser()
350 cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) 344 cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
351 cmd_prompt.cmdloop() 345 cmd_prompt.cmdloop()