| Field | Value | Date |
|---|---|---|
| author | Volpeon <git@volpeon.ink> | 2022-10-27 17:57:05 +0200 |
| committer | Volpeon <git@volpeon.ink> | 2022-10-27 17:57:05 +0200 |
| commit | 0bc909409648a3cae0061c3de2b39e486473ae39 | |
| tree | 5fdbcd7c56919293963c3c8b53bdb2099834079d /dreambooth.py | |
| parent | Euler_a: Re-introduce generator arg for reproducible output | |
Added CLI arg to set dataloader worker num; improved text encoder handling with Dreambooth
Diffstat (limited to 'dreambooth.py')
| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | dreambooth.py | 41 |

1 file changed, 33 insertions, 8 deletions
diff --git a/dreambooth.py b/dreambooth.py
index db097e5..e71b7f0 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -72,15 +72,23 @@ def parse_args():
         "--placeholder_token",
         type=str,
         nargs='*',
+        default=[],
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
         "--initializer_token",
         type=str,
         nargs='*',
+        default=[],
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--train_text_encoder",
+        action="store_true",
+        default=True,
+        help="Whether to train the whole text encoder."
+    )
+    parser.add_argument(
         "--num_class_images",
         type=int,
         default=400,
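For orientation, here is a minimal standalone argparse sketch of how the options touched in this hunk behave; the surrounding parser setup is assumed, not copied from the file. `nargs='*'` plus `default=[]` always yields a list, and `action="store_true"` combined with `default=True` leaves the flag enabled whether or not it is passed:

```python
import argparse

parser = argparse.ArgumentParser()
# nargs='*' with default=[] gives an empty list when the flag is omitted,
# and a list of strings when one or more values are supplied.
parser.add_argument("--placeholder_token", type=str, nargs='*', default=[])
parser.add_argument("--initializer_token", type=str, nargs='*', default=[])
# store_true with default=True means the option is effectively always on.
parser.add_argument("--train_text_encoder", action="store_true", default=True)

args = parser.parse_args(["--placeholder_token", "<cat>", "--initializer_token", "cat"])
print(args.placeholder_token)                    # ['<cat>']
print(args.initializer_token)                    # ['cat']
print(parser.parse_args([]).placeholder_token)   # []
```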
@@ -119,6 +127,15 @@ def parse_args():
         help="Whether to center crop images before resizing to resolution"
     )
     parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
+            " process."
+        ),
+    )
+    parser.add_argument(
         "--num_train_epochs",
         type=int,
         default=100
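A rough sketch of what the new `--dataloader_num_workers` flag ultimately controls: the value becomes the `num_workers` argument of a PyTorch `DataLoader`. This uses a toy dataset; the project's own dataset wiring is not shown here:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(16, dtype=torch.float32).unsqueeze(1))

# num_workers=0 loads batches in the main process; a positive value spawns
# that many worker subprocesses to prepare batches in parallel.
loader = DataLoader(dataset, batch_size=4, num_workers=0)

for (batch,) in loader:
    print(batch.shape)  # torch.Size([4, 1])
```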
@@ -323,7 +340,7 @@ def parse_args():
         args.placeholder_token = [args.placeholder_token]

     if len(args.placeholder_token) == 0:
-        args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
+        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]

     if len(args.placeholder_token) != len(args.initializer_token):
         raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
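The one-line change above fixes a real bug: `range()` needs an integer, but `args.initializer_token` is a list. A standalone sketch of the before/after behavior (token values are illustrative):

```python
initializer_token = ["cat", "dog"]

# Old code: range() over a list raises TypeError.
# placeholder_token = [f"<*{i}>" for i in range(initializer_token)]  # TypeError

# Fixed code: generate one placeholder name per initializer token.
placeholder_token = [f"<*{i}>" for i in range(len(initializer_token))]
print(placeholder_token)  # ['<*0>', '<*1>']
```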
@@ -391,6 +408,9 @@ class Checkpointer:

     @torch.no_grad()
     def save_embedding(self, step, postfix):
+        if len(self.placeholder_token) == 0:
+            return
+
         print("Saving checkpoint for step %d..." % step)

         checkpoints_path = self.output_dir.joinpath("checkpoints")
@@ -406,9 +426,6 @@ class Checkpointer:
             filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
             torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))

-        del unwrapped
-        del learned_embeds
-
     @torch.no_grad()
     def save_model(self):
         print("Saving model...")
@@ -575,7 +592,9 @@ def main():
     # Freeze text_encoder and vae
     freeze_params(vae.parameters())

-    if len(args.initializer_token) != 0:
+    if len(args.placeholder_token) != 0:
+        print(f"Adding text embeddings: {args.placeholder_token}")
+
         # Convert the initializer_token, placeholder_token to ids
         initializer_token_ids = torch.stack([
             torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
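The `[:1]` slice in the context lines keeps only the first sub-token id when an initializer word is split into several pieces, so each initializer contributes exactly one embedding vector. A sketch of that step with the Hugging Face tokenizer API (the model name is illustrative):

```python
import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

initializer_token = ["cat", "watercolor"]

# Each word may encode to one or more ids; [:1] keeps just the first id.
initializer_token_ids = torch.stack([
    torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
    for token in initializer_token
])
print(initializer_token_ids.shape)  # torch.Size([2, 1])
```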
@@ -597,14 +616,19 @@ def main():

         for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
             token_embeds[token_id] = embeddings
+    else:
+        placeholder_token_id = []
+
+    if args.train_text_encoder:
+        print(f"Training entire text encoder.")
+    else:
+        print(f"Training added text embeddings")

         freeze_params(itertools.chain(
             text_encoder.text_model.encoder.parameters(),
             text_encoder.text_model.final_layer_norm.parameters(),
             text_encoder.text_model.embeddings.position_embedding.parameters(),
         ))
-    else:
-        placeholder_token_id = []

     prompt_processor = PromptProcessor(tokenizer, text_encoder)

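Putting the pieces of this hunk together, the resulting control flow looks roughly like the sketch below; `freeze_params` is assumed to simply disable gradients, which matches how it is used here:

```python
import itertools

def freeze_params(params):
    # Assumed helper: stop gradient updates for the given parameters.
    for param in params:
        param.requires_grad = False

def configure_text_encoder(text_encoder, train_text_encoder):
    if train_text_encoder:
        print("Training entire text encoder.")
    else:
        print("Training added text embeddings")
        # Freeze everything except the token embedding table, so only the
        # newly added placeholder embeddings can learn.
        freeze_params(itertools.chain(
            text_encoder.text_model.encoder.parameters(),
            text_encoder.text_model.final_layer_norm.parameters(),
            text_encoder.text_model.embeddings.position_embedding.parameters(),
        ))
```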
@@ -700,6 +724,7 @@ def main():
         repeats=args.repeats,
         center_crop=args.center_crop,
         valid_set_size=args.sample_batch_size*args.sample_batches,
+        num_workers=args.dataloader_num_workers,
         collate_fn=collate_fn
     )

@@ -906,7 +931,7 @@ def main():

                 accelerator.backward(loss)

-                if args.initializer_token is not None:
+                if not args.train_text_encoder:
                     # Keep the token embeddings fixed except the newly added
                     # embeddings for the concept, as we only want to optimize the concept embeddings
                     if accelerator.num_processes > 1:
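The code after this hunk is not shown, but the surrounding comments describe the usual textual-inversion trick: zero the embedding gradient for every token except the newly added placeholders, so only the concept embeddings are optimized. A hedged sketch of that idea (function name and details are assumptions, not this file's exact code):

```python
import torch

def mask_embedding_grads(token_embedding, placeholder_token_id):
    # token_embedding: the text encoder's nn.Embedding for input tokens.
    # Zero the gradient rows of all tokens except the placeholder ids,
    # so only the new concept embeddings receive optimizer updates.
    if token_embedding.weight.grad is None:
        return
    keep = torch.zeros(
        token_embedding.num_embeddings,
        dtype=torch.bool,
        device=token_embedding.weight.device,
    )
    keep[placeholder_token_id] = True
    token_embedding.weight.grad[~keep] = 0.0
```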
