path: root/infer.py
author     Volpeon <git@volpeon.ink>  2023-06-21 13:28:49 +0200
committer  Volpeon <git@volpeon.ink>  2023-06-21 13:28:49 +0200
commit     8364ce697ddf6117fdd4f7222832d546d63880de (patch)
tree       152c99815bbd8b2659d0dabe63c98f63151c97c2 /infer.py
parent     Fix LoRA training with DAdan (diff)
Update
Diffstat (limited to 'infer.py')
-rw-r--r--  infer.py  124
1 file changed, 83 insertions(+), 41 deletions(-)
diff --git a/infer.py b/infer.py
index 7346de9..3b3b595 100644
--- a/infer.py
+++ b/infer.py
@@ -24,7 +24,7 @@ from diffusers import (
     KDPM2DiscreteScheduler,
     KDPM2AncestralDiscreteScheduler,
     DEISMultistepScheduler,
-    UniPCMultistepScheduler
+    UniPCMultistepScheduler,
 )
 from peft import LoraConfig, LoraModel, set_peft_model_state_dict
 from safetensors.torch import load_file
@@ -61,7 +61,7 @@ default_cmds = {
61 "negative_prompt": None, 61 "negative_prompt": None,
62 "shuffle": False, 62 "shuffle": False,
63 "image": None, 63 "image": None,
64 "image_noise": .7, 64 "image_noise": 0.7,
65 "width": 768, 65 "width": 768,
66 "height": 768, 66 "height": 768,
67 "batch_size": 1, 67 "batch_size": 1,
@@ -69,7 +69,6 @@ default_cmds = {
69 "steps": 30, 69 "steps": 30,
70 "guidance_scale": 7.0, 70 "guidance_scale": 7.0,
71 "sag_scale": 0, 71 "sag_scale": 0,
72 "brightness_offset": 0,
73 "seed": None, 72 "seed": None,
74 "config": None, 73 "config": None,
75} 74}
@@ -85,9 +84,7 @@ def merge_dicts(d1, *args):
 
 
 def create_args_parser():
-    parser = argparse.ArgumentParser(
-        description="Simple example of a training script."
-    )
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
         "--model",
         type=str,
@@ -118,9 +115,7 @@ def create_args_parser():
 
 
 def create_cmd_parser():
-    parser = argparse.ArgumentParser(
-        description="Simple example of a training script."
-    )
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
         "--project",
         type=str,
@@ -130,13 +125,34 @@ def create_cmd_parser():
     parser.add_argument(
         "--scheduler",
         type=str,
-        choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "deis", "unipc"],
+        choices=[
+            "plms",
+            "ddim",
+            "klms",
+            "dpmsm",
+            "dpmss",
+            "euler_a",
+            "kdpm2",
+            "kdpm2_a",
+            "deis",
+            "unipc",
+        ],
     )
     parser.add_argument(
         "--subscheduler",
         type=str,
         default=None,
-        choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "deis"],
+        choices=[
+            "plms",
+            "ddim",
+            "klms",
+            "dpmsm",
+            "dpmss",
+            "euler_a",
+            "kdpm2",
+            "kdpm2_a",
+            "deis",
+        ],
     )
     parser.add_argument(
         "--template",
@@ -193,10 +209,6 @@ def create_cmd_parser():
         type=float,
     )
     parser.add_argument(
-        "--brightness_offset",
-        type=float,
-    )
-    parser.add_argument(
         "--seed",
         type=int,
     )
@@ -214,7 +226,9 @@ def run_parser(parser, defaults, input=None):
 
     if args.config is not None:
         conf_args = load_config(args.config)
-        conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[0]
+        conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[
+            0
+        ]
 
     res = defaults.copy()
     for dict in [vars(conf_args), vars(args)]:
@@ -234,10 +248,12 @@ def load_embeddings_dir(pipeline, embeddings_dir):
     added_tokens, added_ids = load_embeddings_from_dir(
         pipeline.tokenizer,
         pipeline.text_encoder.text_model.embeddings,
-        Path(embeddings_dir)
+        Path(embeddings_dir),
     )
     pipeline.text_encoder.text_model.embeddings.persist()
-    print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
+    print(
+        f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
+    )
 
 
 def load_lora(pipeline, path):
@@ -255,9 +271,13 @@ def load_lora(pipeline, path):
         return
 
     lora_checkpoint_sd = load_file(path / tensor_files[0])
-    unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
+    unet_lora_ds = {
+        k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k
+    }
     text_encoder_lora_ds = {
-        k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
+        k.replace("text_encoder_", ""): v
+        for k, v in lora_checkpoint_sd.items()
+        if "text_encoder_" in k
     }
     ti_lora_ds = {
         k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k
@@ -282,7 +302,9 @@ def load_lora(pipeline, path):
         token_embeddings=token_embeddings,
     )
     pipeline.text_encoder.text_model.embeddings.persist()
-    print(f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}")
+    print(
+        f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}"
+    )
 
     return
 
@@ -315,17 +337,25 @@ def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None)
             solver_p=create_scheduler(config, subscheduler),
         )
     else:
-        raise ValueError(f"Unknown scheduler \"{scheduler}\"")
+        raise ValueError(f'Unknown scheduler "{scheduler}"')
 
 
 def create_pipeline(model, dtype):
     print("Loading Stable Diffusion pipeline...")
 
-    tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
-    text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
-    vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
-    unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
-    scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
+    tokenizer = MultiCLIPTokenizer.from_pretrained(
+        model, subfolder="tokenizer", torch_dtype=dtype
+    )
+    text_encoder = CLIPTextModel.from_pretrained(
+        model, subfolder="text_encoder", torch_dtype=dtype
+    )
+    vae = AutoencoderKL.from_pretrained(model, subfolder="vae", torch_dtype=dtype)
+    unet = UNet2DConditionModel.from_pretrained(
+        model, subfolder="unet", torch_dtype=dtype
+    )
+    scheduler = DDIMScheduler.from_pretrained(
+        model, subfolder="scheduler", torch_dtype=dtype
+    )
 
     patch_managed_embeddings(text_encoder)
 
@@ -347,7 +377,9 @@ def create_pipeline(model, dtype):
 
 
 def shuffle_prompts(prompts: list[str]) -> list[str]:
-    return [keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts]
+    return [
+        keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts
+    ]
 
 
 @torch.inference_mode()
@@ -386,12 +418,13 @@ def generate(output_dir: Path, pipeline, args):
     else:
         init_image = None
 
-    pipeline.scheduler = create_scheduler(pipeline.scheduler.config, args.scheduler, args.subscheduler)
+    pipeline.scheduler = create_scheduler(
+        pipeline.scheduler.config, args.scheduler, args.subscheduler
+    )
 
     for i in range(args.batch_num):
         pipeline.set_progress_bar_config(
-            desc=f"Batch {i + 1} of {args.batch_num}",
-            dynamic_ncols=True
+            desc=f"Batch {i + 1} of {args.batch_num}", dynamic_ncols=True
         )
 
         seed = args.seed + i
@@ -409,7 +442,6 @@ def generate(output_dir: Path, pipeline, args):
             generator=generator,
             image=init_image,
             strength=args.image_noise,
-            brightness_offset=args.brightness_offset,
         ).images
 
         for j, image in enumerate(images):
@@ -418,7 +450,7 @@ def generate(output_dir: Path, pipeline, args):
 
             image.save(dir / f"{basename}.png")
             image.save(dir / f"{basename}.jpg", quality=85)
-            with open(dir / f"{basename}.txt", 'w') as f:
+            with open(dir / f"{basename}.txt", "w") as f:
                 f.write(prompt[j % len(args.prompt)])
 
         if torch.cuda.is_available():
@@ -426,10 +458,12 @@ def generate(output_dir: Path, pipeline, args):
 
 
 class CmdParse(cmd.Cmd):
-    prompt = 'dream> '
+    prompt = "dream> "
     commands = []
 
-    def __init__(self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser):
+    def __init__(
+        self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser
+    ):
         super().__init__()
 
         self.output_dir = output_dir
@@ -447,10 +481,10 @@ class CmdParse(cmd.Cmd):
             print(str(e))
             return
 
-        if elements[0] == 'q':
+        if elements[0] == "q":
             return True
 
-        if elements[0] == 'reload_embeddings':
+        if elements[0] == "reload_embeddings":
             load_embeddings_dir(self.pipeline, self.ti_embeddings_dir)
             return
 
@@ -458,7 +492,7 @@ class CmdParse(cmd.Cmd):
             args = run_parser(self.parser, default_cmds, elements)
 
             if len(args.prompt) == 0:
-                print('Try again with a prompt!')
+                print("Try again with a prompt!")
                 return
         except SystemExit:
             traceback.print_exc()
@@ -471,7 +505,7 @@ class CmdParse(cmd.Cmd):
         try:
             generate(self.output_dir, self.pipeline, args)
         except KeyboardInterrupt:
-            print('Generation cancelled.')
+            print("Generation cancelled.")
         except Exception as e:
             traceback.print_exc()
             return
@@ -487,7 +521,9 @@ def main():
     args = run_parser(args_parser, default_args)
 
     output_dir = Path(args.output_dir)
-    dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
+    dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[
+        args.precision
+    ]
 
     pipeline = create_pipeline(args.model, dtype)
 
@@ -496,7 +532,13 @@ def main():
     # pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
 
     cmd_parser = create_cmd_parser()
-    cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser)
+    cmd_prompt = CmdParse(
+        output_dir,
+        args.ti_embeddings_dir,
+        args.lora_embeddings_dir,
+        pipeline,
+        cmd_parser,
+    )
     cmd_prompt.cmdloop()
 
 
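
For readers skimming the diff: the reformatted create_scheduler()/generate() code above amounts to the usual diffusers scheduler swap, including chaining a subscheduler into UniPC via solver_p. The snippet below is a minimal sketch of that pattern and is not code from this commit; the model ID is a placeholder and the stock StableDiffusionPipeline stands in for whatever pipeline create_pipeline() actually assembles.

# Hedged sketch, not part of the commit: scheduler swap as seen in generate(),
# here with "unipc" plus a "ddim" subscheduler passed as solver_p.
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline, UniPCMultistepScheduler

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder model ID, not from infer.py
    torch_dtype=torch.float16,
)

# Reuse the existing scheduler config as the shared base, mirroring
# create_scheduler(pipeline.scheduler.config, args.scheduler, args.subscheduler).
config = pipe.scheduler.config
pipe.scheduler = UniPCMultistepScheduler.from_config(
    config,
    solver_p=DDIMScheduler.from_config(config),  # the optional subscheduler
)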