diff options
author | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
commit | 8364ce697ddf6117fdd4f7222832d546d63880de (patch) | |
tree | 152c99815bbd8b2659d0dabe63c98f63151c97c2 /infer.py | |
parent | Fix LoRA training with DAdan (diff) | |
download | textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.gz textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.bz2 textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.zip |
Update
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 124 |
1 files changed, 83 insertions, 41 deletions
@@ -24,7 +24,7 @@ from diffusers import ( | |||
24 | KDPM2DiscreteScheduler, | 24 | KDPM2DiscreteScheduler, |
25 | KDPM2AncestralDiscreteScheduler, | 25 | KDPM2AncestralDiscreteScheduler, |
26 | DEISMultistepScheduler, | 26 | DEISMultistepScheduler, |
27 | UniPCMultistepScheduler | 27 | UniPCMultistepScheduler, |
28 | ) | 28 | ) |
29 | from peft import LoraConfig, LoraModel, set_peft_model_state_dict | 29 | from peft import LoraConfig, LoraModel, set_peft_model_state_dict |
30 | from safetensors.torch import load_file | 30 | from safetensors.torch import load_file |
@@ -61,7 +61,7 @@ default_cmds = { | |||
61 | "negative_prompt": None, | 61 | "negative_prompt": None, |
62 | "shuffle": False, | 62 | "shuffle": False, |
63 | "image": None, | 63 | "image": None, |
64 | "image_noise": .7, | 64 | "image_noise": 0.7, |
65 | "width": 768, | 65 | "width": 768, |
66 | "height": 768, | 66 | "height": 768, |
67 | "batch_size": 1, | 67 | "batch_size": 1, |
@@ -69,7 +69,6 @@ default_cmds = { | |||
69 | "steps": 30, | 69 | "steps": 30, |
70 | "guidance_scale": 7.0, | 70 | "guidance_scale": 7.0, |
71 | "sag_scale": 0, | 71 | "sag_scale": 0, |
72 | "brightness_offset": 0, | ||
73 | "seed": None, | 72 | "seed": None, |
74 | "config": None, | 73 | "config": None, |
75 | } | 74 | } |
@@ -85,9 +84,7 @@ def merge_dicts(d1, *args): | |||
85 | 84 | ||
86 | 85 | ||
87 | def create_args_parser(): | 86 | def create_args_parser(): |
88 | parser = argparse.ArgumentParser( | 87 | parser = argparse.ArgumentParser(description="Simple example of a training script.") |
89 | description="Simple example of a training script." | ||
90 | ) | ||
91 | parser.add_argument( | 88 | parser.add_argument( |
92 | "--model", | 89 | "--model", |
93 | type=str, | 90 | type=str, |
@@ -118,9 +115,7 @@ def create_args_parser(): | |||
118 | 115 | ||
119 | 116 | ||
120 | def create_cmd_parser(): | 117 | def create_cmd_parser(): |
121 | parser = argparse.ArgumentParser( | 118 | parser = argparse.ArgumentParser(description="Simple example of a training script.") |
122 | description="Simple example of a training script." | ||
123 | ) | ||
124 | parser.add_argument( | 119 | parser.add_argument( |
125 | "--project", | 120 | "--project", |
126 | type=str, | 121 | type=str, |
@@ -130,13 +125,34 @@ def create_cmd_parser(): | |||
130 | parser.add_argument( | 125 | parser.add_argument( |
131 | "--scheduler", | 126 | "--scheduler", |
132 | type=str, | 127 | type=str, |
133 | choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "deis", "unipc"], | 128 | choices=[ |
129 | "plms", | ||
130 | "ddim", | ||
131 | "klms", | ||
132 | "dpmsm", | ||
133 | "dpmss", | ||
134 | "euler_a", | ||
135 | "kdpm2", | ||
136 | "kdpm2_a", | ||
137 | "deis", | ||
138 | "unipc", | ||
139 | ], | ||
134 | ) | 140 | ) |
135 | parser.add_argument( | 141 | parser.add_argument( |
136 | "--subscheduler", | 142 | "--subscheduler", |
137 | type=str, | 143 | type=str, |
138 | default=None, | 144 | default=None, |
139 | choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "deis"], | 145 | choices=[ |
146 | "plms", | ||
147 | "ddim", | ||
148 | "klms", | ||
149 | "dpmsm", | ||
150 | "dpmss", | ||
151 | "euler_a", | ||
152 | "kdpm2", | ||
153 | "kdpm2_a", | ||
154 | "deis", | ||
155 | ], | ||
140 | ) | 156 | ) |
141 | parser.add_argument( | 157 | parser.add_argument( |
142 | "--template", | 158 | "--template", |
@@ -193,10 +209,6 @@ def create_cmd_parser(): | |||
193 | type=float, | 209 | type=float, |
194 | ) | 210 | ) |
195 | parser.add_argument( | 211 | parser.add_argument( |
196 | "--brightness_offset", | ||
197 | type=float, | ||
198 | ) | ||
199 | parser.add_argument( | ||
200 | "--seed", | 212 | "--seed", |
201 | type=int, | 213 | type=int, |
202 | ) | 214 | ) |
@@ -214,7 +226,9 @@ def run_parser(parser, defaults, input=None): | |||
214 | 226 | ||
215 | if args.config is not None: | 227 | if args.config is not None: |
216 | conf_args = load_config(args.config) | 228 | conf_args = load_config(args.config) |
217 | conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[0] | 229 | conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[ |
230 | 0 | ||
231 | ] | ||
218 | 232 | ||
219 | res = defaults.copy() | 233 | res = defaults.copy() |
220 | for dict in [vars(conf_args), vars(args)]: | 234 | for dict in [vars(conf_args), vars(args)]: |
@@ -234,10 +248,12 @@ def load_embeddings_dir(pipeline, embeddings_dir): | |||
234 | added_tokens, added_ids = load_embeddings_from_dir( | 248 | added_tokens, added_ids = load_embeddings_from_dir( |
235 | pipeline.tokenizer, | 249 | pipeline.tokenizer, |
236 | pipeline.text_encoder.text_model.embeddings, | 250 | pipeline.text_encoder.text_model.embeddings, |
237 | Path(embeddings_dir) | 251 | Path(embeddings_dir), |
238 | ) | 252 | ) |
239 | pipeline.text_encoder.text_model.embeddings.persist() | 253 | pipeline.text_encoder.text_model.embeddings.persist() |
240 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 254 | print( |
255 | f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" | ||
256 | ) | ||
241 | 257 | ||
242 | 258 | ||
243 | def load_lora(pipeline, path): | 259 | def load_lora(pipeline, path): |
@@ -255,9 +271,13 @@ def load_lora(pipeline, path): | |||
255 | return | 271 | return |
256 | 272 | ||
257 | lora_checkpoint_sd = load_file(path / tensor_files[0]) | 273 | lora_checkpoint_sd = load_file(path / tensor_files[0]) |
258 | unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k} | 274 | unet_lora_ds = { |
275 | k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k | ||
276 | } | ||
259 | text_encoder_lora_ds = { | 277 | text_encoder_lora_ds = { |
260 | k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k | 278 | k.replace("text_encoder_", ""): v |
279 | for k, v in lora_checkpoint_sd.items() | ||
280 | if "text_encoder_" in k | ||
261 | } | 281 | } |
262 | ti_lora_ds = { | 282 | ti_lora_ds = { |
263 | k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k | 283 | k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k |
@@ -282,7 +302,9 @@ def load_lora(pipeline, path): | |||
282 | token_embeddings=token_embeddings, | 302 | token_embeddings=token_embeddings, |
283 | ) | 303 | ) |
284 | pipeline.text_encoder.text_model.embeddings.persist() | 304 | pipeline.text_encoder.text_model.embeddings.persist() |
285 | print(f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}") | 305 | print( |
306 | f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}" | ||
307 | ) | ||
286 | 308 | ||
287 | return | 309 | return |
288 | 310 | ||
@@ -315,17 +337,25 @@ def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None) | |||
315 | solver_p=create_scheduler(config, subscheduler), | 337 | solver_p=create_scheduler(config, subscheduler), |
316 | ) | 338 | ) |
317 | else: | 339 | else: |
318 | raise ValueError(f"Unknown scheduler \"{scheduler}\"") | 340 | raise ValueError(f'Unknown scheduler "{scheduler}"') |
319 | 341 | ||
320 | 342 | ||
321 | def create_pipeline(model, dtype): | 343 | def create_pipeline(model, dtype): |
322 | print("Loading Stable Diffusion pipeline...") | 344 | print("Loading Stable Diffusion pipeline...") |
323 | 345 | ||
324 | tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) | 346 | tokenizer = MultiCLIPTokenizer.from_pretrained( |
325 | text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype) | 347 | model, subfolder="tokenizer", torch_dtype=dtype |
326 | vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype) | 348 | ) |
327 | unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) | 349 | text_encoder = CLIPTextModel.from_pretrained( |
328 | scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) | 350 | model, subfolder="text_encoder", torch_dtype=dtype |
351 | ) | ||
352 | vae = AutoencoderKL.from_pretrained(model, subfolder="vae", torch_dtype=dtype) | ||
353 | unet = UNet2DConditionModel.from_pretrained( | ||
354 | model, subfolder="unet", torch_dtype=dtype | ||
355 | ) | ||
356 | scheduler = DDIMScheduler.from_pretrained( | ||
357 | model, subfolder="scheduler", torch_dtype=dtype | ||
358 | ) | ||
329 | 359 | ||
330 | patch_managed_embeddings(text_encoder) | 360 | patch_managed_embeddings(text_encoder) |
331 | 361 | ||
@@ -347,7 +377,9 @@ def create_pipeline(model, dtype): | |||
347 | 377 | ||
348 | 378 | ||
349 | def shuffle_prompts(prompts: list[str]) -> list[str]: | 379 | def shuffle_prompts(prompts: list[str]) -> list[str]: |
350 | return [keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts] | 380 | return [ |
381 | keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts | ||
382 | ] | ||
351 | 383 | ||
352 | 384 | ||
353 | @torch.inference_mode() | 385 | @torch.inference_mode() |
@@ -386,12 +418,13 @@ def generate(output_dir: Path, pipeline, args): | |||
386 | else: | 418 | else: |
387 | init_image = None | 419 | init_image = None |
388 | 420 | ||
389 | pipeline.scheduler = create_scheduler(pipeline.scheduler.config, args.scheduler, args.subscheduler) | 421 | pipeline.scheduler = create_scheduler( |
422 | pipeline.scheduler.config, args.scheduler, args.subscheduler | ||
423 | ) | ||
390 | 424 | ||
391 | for i in range(args.batch_num): | 425 | for i in range(args.batch_num): |
392 | pipeline.set_progress_bar_config( | 426 | pipeline.set_progress_bar_config( |
393 | desc=f"Batch {i + 1} of {args.batch_num}", | 427 | desc=f"Batch {i + 1} of {args.batch_num}", dynamic_ncols=True |
394 | dynamic_ncols=True | ||
395 | ) | 428 | ) |
396 | 429 | ||
397 | seed = args.seed + i | 430 | seed = args.seed + i |
@@ -409,7 +442,6 @@ def generate(output_dir: Path, pipeline, args): | |||
409 | generator=generator, | 442 | generator=generator, |
410 | image=init_image, | 443 | image=init_image, |
411 | strength=args.image_noise, | 444 | strength=args.image_noise, |
412 | brightness_offset=args.brightness_offset, | ||
413 | ).images | 445 | ).images |
414 | 446 | ||
415 | for j, image in enumerate(images): | 447 | for j, image in enumerate(images): |
@@ -418,7 +450,7 @@ def generate(output_dir: Path, pipeline, args): | |||
418 | 450 | ||
419 | image.save(dir / f"{basename}.png") | 451 | image.save(dir / f"{basename}.png") |
420 | image.save(dir / f"{basename}.jpg", quality=85) | 452 | image.save(dir / f"{basename}.jpg", quality=85) |
421 | with open(dir / f"{basename}.txt", 'w') as f: | 453 | with open(dir / f"{basename}.txt", "w") as f: |
422 | f.write(prompt[j % len(args.prompt)]) | 454 | f.write(prompt[j % len(args.prompt)]) |
423 | 455 | ||
424 | if torch.cuda.is_available(): | 456 | if torch.cuda.is_available(): |
@@ -426,10 +458,12 @@ def generate(output_dir: Path, pipeline, args): | |||
426 | 458 | ||
427 | 459 | ||
428 | class CmdParse(cmd.Cmd): | 460 | class CmdParse(cmd.Cmd): |
429 | prompt = 'dream> ' | 461 | prompt = "dream> " |
430 | commands = [] | 462 | commands = [] |
431 | 463 | ||
432 | def __init__(self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser): | 464 | def __init__( |
465 | self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser | ||
466 | ): | ||
433 | super().__init__() | 467 | super().__init__() |
434 | 468 | ||
435 | self.output_dir = output_dir | 469 | self.output_dir = output_dir |
@@ -447,10 +481,10 @@ class CmdParse(cmd.Cmd): | |||
447 | print(str(e)) | 481 | print(str(e)) |
448 | return | 482 | return |
449 | 483 | ||
450 | if elements[0] == 'q': | 484 | if elements[0] == "q": |
451 | return True | 485 | return True |
452 | 486 | ||
453 | if elements[0] == 'reload_embeddings': | 487 | if elements[0] == "reload_embeddings": |
454 | load_embeddings_dir(self.pipeline, self.ti_embeddings_dir) | 488 | load_embeddings_dir(self.pipeline, self.ti_embeddings_dir) |
455 | return | 489 | return |
456 | 490 | ||
@@ -458,7 +492,7 @@ class CmdParse(cmd.Cmd): | |||
458 | args = run_parser(self.parser, default_cmds, elements) | 492 | args = run_parser(self.parser, default_cmds, elements) |
459 | 493 | ||
460 | if len(args.prompt) == 0: | 494 | if len(args.prompt) == 0: |
461 | print('Try again with a prompt!') | 495 | print("Try again with a prompt!") |
462 | return | 496 | return |
463 | except SystemExit: | 497 | except SystemExit: |
464 | traceback.print_exc() | 498 | traceback.print_exc() |
@@ -471,7 +505,7 @@ class CmdParse(cmd.Cmd): | |||
471 | try: | 505 | try: |
472 | generate(self.output_dir, self.pipeline, args) | 506 | generate(self.output_dir, self.pipeline, args) |
473 | except KeyboardInterrupt: | 507 | except KeyboardInterrupt: |
474 | print('Generation cancelled.') | 508 | print("Generation cancelled.") |
475 | except Exception as e: | 509 | except Exception as e: |
476 | traceback.print_exc() | 510 | traceback.print_exc() |
477 | return | 511 | return |
@@ -487,7 +521,9 @@ def main(): | |||
487 | args = run_parser(args_parser, default_args) | 521 | args = run_parser(args_parser, default_args) |
488 | 522 | ||
489 | output_dir = Path(args.output_dir) | 523 | output_dir = Path(args.output_dir) |
490 | dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] | 524 | dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[ |
525 | args.precision | ||
526 | ] | ||
491 | 527 | ||
492 | pipeline = create_pipeline(args.model, dtype) | 528 | pipeline = create_pipeline(args.model, dtype) |
493 | 529 | ||
@@ -496,7 +532,13 @@ def main(): | |||
496 | # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) | 532 | # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) |
497 | 533 | ||
498 | cmd_parser = create_cmd_parser() | 534 | cmd_parser = create_cmd_parser() |
499 | cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser) | 535 | cmd_prompt = CmdParse( |
536 | output_dir, | ||
537 | args.ti_embeddings_dir, | ||
538 | args.lora_embeddings_dir, | ||
539 | pipeline, | ||
540 | cmd_parser, | ||
541 | ) | ||
500 | cmd_prompt.cmdloop() | 542 | cmd_prompt.cmdloop() |
501 | 543 | ||
502 | 544 | ||