| author | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
| commit | 8364ce697ddf6117fdd4f7222832d546d63880de (patch) | |
| tree | 152c99815bbd8b2659d0dabe63c98f63151c97c2 /infer.py | |
| parent | Fix LoRA training with DAdan (diff) | |
| download | textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.gz<br>textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.bz2<br>textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.zip | |
Update
Diffstat (limited to 'infer.py')
| -rw-r--r-- | infer.py | 124 |
1 file changed, 83 insertions, 41 deletions
```diff
@@ -24,7 +24,7 @@ from diffusers import (
     KDPM2DiscreteScheduler,
     KDPM2AncestralDiscreteScheduler,
     DEISMultistepScheduler,
-    UniPCMultistepScheduler
+    UniPCMultistepScheduler,
 )
 from peft import LoraConfig, LoraModel, set_peft_model_state_dict
 from safetensors.torch import load_file
@@ -61,7 +61,7 @@ default_cmds = {
     "negative_prompt": None,
     "shuffle": False,
     "image": None,
-    "image_noise": .7,
+    "image_noise": 0.7,
     "width": 768,
     "height": 768,
     "batch_size": 1,
@@ -69,7 +69,6 @@ default_cmds = {
     "steps": 30,
     "guidance_scale": 7.0,
     "sag_scale": 0,
-    "brightness_offset": 0,
     "seed": None,
     "config": None,
 }
```
```diff
@@ -85,9 +84,7 @@ def merge_dicts(d1, *args):
 
 
 def create_args_parser():
-    parser = argparse.ArgumentParser(
-        description="Simple example of a training script."
-    )
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
         "--model",
         type=str,
@@ -118,9 +115,7 @@ def create_args_parser():
 
 
 def create_cmd_parser():
-    parser = argparse.ArgumentParser(
-        description="Simple example of a training script."
-    )
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
         "--project",
         type=str,
@@ -130,13 +125,34 @@ def create_cmd_parser():
     parser.add_argument(
         "--scheduler",
         type=str,
-        choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "deis", "unipc"],
+        choices=[
+            "plms",
+            "ddim",
+            "klms",
+            "dpmsm",
+            "dpmss",
+            "euler_a",
+            "kdpm2",
+            "kdpm2_a",
+            "deis",
+            "unipc",
+        ],
     )
     parser.add_argument(
         "--subscheduler",
         type=str,
         default=None,
-        choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "deis"],
+        choices=[
+            "plms",
+            "ddim",
+            "klms",
+            "dpmsm",
+            "dpmss",
+            "euler_a",
+            "kdpm2",
+            "kdpm2_a",
+            "deis",
+        ],
     )
     parser.add_argument(
         "--template",
@@ -193,10 +209,6 @@ def create_cmd_parser():
         type=float,
     )
     parser.add_argument(
-        "--brightness_offset",
-        type=float,
-    )
-    parser.add_argument(
         "--seed",
         type=int,
     )
```
```diff
@@ -214,7 +226,9 @@ def run_parser(parser, defaults, input=None):
 
     if args.config is not None:
         conf_args = load_config(args.config)
-        conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[0]
+        conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[
+            0
+        ]
 
     res = defaults.copy()
     for dict in [vars(conf_args), vars(args)]:
```
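The reflowed `parse_known_args` call is the heart of the config handling: values from a config file are loaded into a namespace first, and flags given on the command line override them. A minimal, self-contained sketch of that pattern (the JSON file, its contents, and the None-skipping merge rule are assumptions based on the surrounding code):

```python
import argparse
import json
from pathlib import Path

# Hypothetical config file standing in for the project's load_config().
Path("cmd.json").write_text(json.dumps({"steps": 30, "guidance_scale": 7.0}))

parser = argparse.ArgumentParser()
parser.add_argument("--steps", type=int)
parser.add_argument("--guidance_scale", type=float)
parser.add_argument("--config", type=str)

args = parser.parse_args(["--config", "cmd.json", "--steps", "50"])

# Seed a namespace with the file's values, then let argparse fill in
# declared-but-missing options with their defaults.
conf_args = argparse.Namespace(**json.loads(Path(args.config).read_text()))
conf_args = parser.parse_known_args([], namespace=conf_args)[0]

# Merge: config values first, then CLI values that were actually set.
res = vars(conf_args).copy()
res.update({k: v for k, v in vars(args).items() if v is not None})
print(res["steps"], res["guidance_scale"])  # 50 7.0
```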
```diff
@@ -234,10 +248,12 @@ def load_embeddings_dir(pipeline, embeddings_dir):
     added_tokens, added_ids = load_embeddings_from_dir(
         pipeline.tokenizer,
         pipeline.text_encoder.text_model.embeddings,
-        Path(embeddings_dir)
+        Path(embeddings_dir),
     )
     pipeline.text_encoder.text_model.embeddings.persist()
-    print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
+    print(
+        f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
+    )
 
 
 def load_lora(pipeline, path):
@@ -255,9 +271,13 @@ def load_lora(pipeline, path):
         return
 
     lora_checkpoint_sd = load_file(path / tensor_files[0])
-    unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
+    unet_lora_ds = {
+        k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k
+    }
     text_encoder_lora_ds = {
-        k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
+        k.replace("text_encoder_", ""): v
+        for k, v in lora_checkpoint_sd.items()
+        if "text_encoder_" in k
     }
     ti_lora_ds = {
         k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k
@@ -282,7 +302,9 @@ def load_lora(pipeline, path):
         token_embeddings=token_embeddings,
     )
     pipeline.text_encoder.text_model.embeddings.persist()
-    print(f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}")
+    print(
+        f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}"
+    )
 
     return
 
```
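The reflowed dict comprehensions in `load_lora` implement a simple routing scheme: one safetensors checkpoint holds UNet, text-encoder, and textual-inversion weights side by side, distinguished only by key substring. A toy illustration of the split (the key names are invented; note that because the filters match substrings, `ti_` entries also pass the UNet filter):

```python
import torch

# Invented keys mimicking a combined LoRA checkpoint.
lora_checkpoint_sd = {
    "unet.up_blocks.0.lora_A.weight": torch.zeros(4, 8),
    "text_encoder_layers.0.lora_A.weight": torch.zeros(4, 8),
    "ti_token1": torch.zeros(768),
}

unet_lora_ds = {
    k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k
}
text_encoder_lora_ds = {
    k.replace("text_encoder_", ""): v
    for k, v in lora_checkpoint_sd.items()
    if "text_encoder_" in k
}
ti_lora_ds = {
    k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k
}

print(sorted(unet_lora_ds))          # the unet key *and* the ti_ key
print(sorted(text_encoder_lora_ds))  # ['layers.0.lora_A.weight']
print(sorted(ti_lora_ds))            # ['token1']
```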
```diff
@@ -315,17 +337,25 @@ def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None)
             solver_p=create_scheduler(config, subscheduler),
         )
     else:
-        raise ValueError(f"Unknown scheduler \"{scheduler}\"")
+        raise ValueError(f'Unknown scheduler "{scheduler}"')
 
 
 def create_pipeline(model, dtype):
     print("Loading Stable Diffusion pipeline...")
 
-    tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
-    text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
-    vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
-    unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
-    scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
+    tokenizer = MultiCLIPTokenizer.from_pretrained(
+        model, subfolder="tokenizer", torch_dtype=dtype
+    )
+    text_encoder = CLIPTextModel.from_pretrained(
+        model, subfolder="text_encoder", torch_dtype=dtype
+    )
+    vae = AutoencoderKL.from_pretrained(model, subfolder="vae", torch_dtype=dtype)
+    unet = UNet2DConditionModel.from_pretrained(
+        model, subfolder="unet", torch_dtype=dtype
+    )
+    scheduler = DDIMScheduler.from_pretrained(
+        model, subfolder="scheduler", torch_dtype=dtype
+    )
 
     patch_managed_embeddings(text_encoder)
 
```
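`create_scheduler` is a name-to-class factory over the diffusers schedulers imported at the top of the file, and the hunk above shows its error path; for `unipc` the code additionally threads an optional sub-scheduler through UniPC's `solver_p` argument. A cut-down sketch of the factory pattern (only three of the ten names, without the sub-scheduler plumbing):

```python
from diffusers import (
    DDIMScheduler,
    EulerAncestralDiscreteScheduler,
    UniPCMultistepScheduler,
)

SCHEDULER_MAP = {
    "ddim": DDIMScheduler,
    "euler_a": EulerAncestralDiscreteScheduler,
    "unipc": UniPCMultistepScheduler,
}


def create_scheduler_sketch(config, scheduler: str):
    try:
        cls = SCHEDULER_MAP[scheduler]
    except KeyError:
        raise ValueError(f'Unknown scheduler "{scheduler}"')
    # from_config reuses the pipeline's trained beta schedule and step count.
    return cls.from_config(config)


# Usage against a loaded pipeline:
# pipeline.scheduler = create_scheduler_sketch(pipeline.scheduler.config, "unipc")
```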
```diff
@@ -347,7 +377,9 @@ def create_pipeline(model, dtype):
 
 
 def shuffle_prompts(prompts: list[str]) -> list[str]:
-    return [keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts]
+    return [
+        keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts
+    ]
 
 
 @torch.inference_mode()
```
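`shuffle_prompts` leans on two helpers defined elsewhere in the repo; the sketch below assumes the simplest behaviour consistent with their names, namely that a prompt is a comma-separated keyword list:

```python
import random


def str_to_keywords(prompt: str) -> list[str]:
    # Assumed behaviour of the repo's helper: split on commas.
    return [k.strip() for k in prompt.split(",") if k.strip()]


def keywords_to_str(keywords: list[str], shuffle: bool = False) -> str:
    if shuffle:
        keywords = random.sample(keywords, len(keywords))
    return ", ".join(keywords)


def shuffle_prompts(prompts: list[str]) -> list[str]:
    return [
        keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts
    ]


print(shuffle_prompts(["a portrait, detailed, studio lighting"]))
```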
```diff
@@ -386,12 +418,13 @@ def generate(output_dir: Path, pipeline, args):
     else:
         init_image = None
 
-    pipeline.scheduler = create_scheduler(pipeline.scheduler.config, args.scheduler, args.subscheduler)
+    pipeline.scheduler = create_scheduler(
+        pipeline.scheduler.config, args.scheduler, args.subscheduler
+    )
 
     for i in range(args.batch_num):
         pipeline.set_progress_bar_config(
-            desc=f"Batch {i + 1} of {args.batch_num}",
-            dynamic_ncols=True
+            desc=f"Batch {i + 1} of {args.batch_num}", dynamic_ncols=True
         )
 
         seed = args.seed + i
@@ -409,7 +442,6 @@ def generate(output_dir: Path, pipeline, args):
             generator=generator,
             image=init_image,
             strength=args.image_noise,
-            brightness_offset=args.brightness_offset,
         ).images
 
         for j, image in enumerate(images):
@@ -418,7 +450,7 @@ def generate(output_dir: Path, pipeline, args):
 
             image.save(dir / f"{basename}.png")
             image.save(dir / f"{basename}.jpg", quality=85)
-            with open(dir / f"{basename}.txt", 'w') as f:
+            with open(dir / f"{basename}.txt", "w") as f:
                 f.write(prompt[j % len(args.prompt)])
 
     if torch.cuda.is_available():
```
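`args.image_noise` is passed to the pipeline as `strength`, the standard diffusers img2img knob: it decides how far the init image is re-noised and therefore how many of the scheduled denoising steps actually run. A worked version of the usual timestep arithmetic, using the new `0.7` default (assuming this pipeline follows the standard diffusers img2img scheduling):

```python
num_inference_steps = 30  # the "steps" default
strength = 0.7            # the "image_noise" default

# Standard img2img scheduling: skip the first (1 - strength) fraction
# of the timestep schedule and denoise only the remainder.
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)

print(init_timestep, t_start)  # 21 9 -> 21 of the 30 steps are executed
```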
```diff
@@ -426,10 +458,12 @@ def generate(output_dir: Path, pipeline, args):
 
 
 class CmdParse(cmd.Cmd):
-    prompt = 'dream> '
+    prompt = "dream> "
     commands = []
 
-    def __init__(self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser):
+    def __init__(
+        self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser
+    ):
         super().__init__()
 
         self.output_dir = output_dir
@@ -447,10 +481,10 @@ class CmdParse(cmd.Cmd):
             print(str(e))
             return
 
-        if elements[0] == 'q':
+        if elements[0] == "q":
             return True
 
-        if elements[0] == 'reload_embeddings':
+        if elements[0] == "reload_embeddings":
             load_embeddings_dir(self.pipeline, self.ti_embeddings_dir)
             return
 
@@ -458,7 +492,7 @@ class CmdParse(cmd.Cmd):
             args = run_parser(self.parser, default_cmds, elements)
 
             if len(args.prompt) == 0:
-                print('Try again with a prompt!')
+                print("Try again with a prompt!")
                 return
         except SystemExit:
             traceback.print_exc()
@@ -471,7 +505,7 @@ class CmdParse(cmd.Cmd):
         try:
             generate(self.output_dir, self.pipeline, args)
         except KeyboardInterrupt:
-            print('Generation cancelled.')
+            print("Generation cancelled.")
         except Exception as e:
             traceback.print_exc()
             return
```
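The interactive loop builds on the standard library's `cmd.Cmd`: `cmdloop()` reads lines at the `dream> ` prompt, routes anything without a matching `do_*` method to `default()`, and stops when a handler returns a truthy value. A stripped-down sketch of the same shape (the tokenizing step is assumed to be `shlex.split`, which raises the `ValueError` that the hunk's `print(str(e))` appears to catch):

```python
import cmd
import shlex


class MiniShell(cmd.Cmd):
    prompt = "dream> "

    def default(self, line):
        # Any line without a matching do_* method lands here.
        try:
            elements = shlex.split(line)
        except ValueError as e:
            print(str(e))  # e.g. "No closing quotation"
            return

        if elements and elements[0] == "q":
            return True  # a truthy return value ends cmdloop()

        print(f"would generate with: {elements}")


if __name__ == "__main__":
    MiniShell().cmdloop()
```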
```diff
@@ -487,7 +521,9 @@ def main():
     args = run_parser(args_parser, default_args)
 
     output_dir = Path(args.output_dir)
-    dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
+    dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[
+        args.precision
+    ]
 
     pipeline = create_pipeline(args.model, dtype)
 
@@ -496,7 +532,13 @@ def main():
     # pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
 
     cmd_parser = create_cmd_parser()
-    cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser)
+    cmd_prompt = CmdParse(
+        output_dir,
+        args.ti_embeddings_dir,
+        args.lora_embeddings_dir,
+        pipeline,
+        cmd_parser,
+    )
     cmd_prompt.cmdloop()
 
 
```