diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 41 |
1 files changed, 33 insertions, 8 deletions
diff --git a/dreambooth.py b/dreambooth.py index db097e5..e71b7f0 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -72,15 +72,23 @@ def parse_args(): | |||
72 | "--placeholder_token", | 72 | "--placeholder_token", |
73 | type=str, | 73 | type=str, |
74 | nargs='*', | 74 | nargs='*', |
75 | default=[], | ||
75 | help="A token to use as a placeholder for the concept.", | 76 | help="A token to use as a placeholder for the concept.", |
76 | ) | 77 | ) |
77 | parser.add_argument( | 78 | parser.add_argument( |
78 | "--initializer_token", | 79 | "--initializer_token", |
79 | type=str, | 80 | type=str, |
80 | nargs='*', | 81 | nargs='*', |
82 | default=[], | ||
81 | help="A token to use as initializer word." | 83 | help="A token to use as initializer word." |
82 | ) | 84 | ) |
83 | parser.add_argument( | 85 | parser.add_argument( |
86 | "--train_text_encoder", | ||
87 | action="store_true", | ||
88 | default=True, | ||
89 | help="Whether to train the whole text encoder." | ||
90 | ) | ||
91 | parser.add_argument( | ||
84 | "--num_class_images", | 92 | "--num_class_images", |
85 | type=int, | 93 | type=int, |
86 | default=400, | 94 | default=400, |
@@ -119,6 +127,15 @@ def parse_args(): | |||
119 | help="Whether to center crop images before resizing to resolution" | 127 | help="Whether to center crop images before resizing to resolution" |
120 | ) | 128 | ) |
121 | parser.add_argument( | 129 | parser.add_argument( |
130 | "--dataloader_num_workers", | ||
131 | type=int, | ||
132 | default=0, | ||
133 | help=( | ||
134 | "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main" | ||
135 | " process." | ||
136 | ), | ||
137 | ) | ||
138 | parser.add_argument( | ||
122 | "--num_train_epochs", | 139 | "--num_train_epochs", |
123 | type=int, | 140 | type=int, |
124 | default=100 | 141 | default=100 |
@@ -323,7 +340,7 @@ def parse_args(): | |||
323 | args.placeholder_token = [args.placeholder_token] | 340 | args.placeholder_token = [args.placeholder_token] |
324 | 341 | ||
325 | if len(args.placeholder_token) == 0: | 342 | if len(args.placeholder_token) == 0: |
326 | args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)] | 343 | args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))] |
327 | 344 | ||
328 | if len(args.placeholder_token) != len(args.initializer_token): | 345 | if len(args.placeholder_token) != len(args.initializer_token): |
329 | raise ValueError("Number of items in --placeholder_token and --initializer_token must match") | 346 | raise ValueError("Number of items in --placeholder_token and --initializer_token must match") |
@@ -391,6 +408,9 @@ class Checkpointer: | |||
391 | 408 | ||
392 | @torch.no_grad() | 409 | @torch.no_grad() |
393 | def save_embedding(self, step, postfix): | 410 | def save_embedding(self, step, postfix): |
411 | if len(self.placeholder_token) == 0: | ||
412 | return | ||
413 | |||
394 | print("Saving checkpoint for step %d..." % step) | 414 | print("Saving checkpoint for step %d..." % step) |
395 | 415 | ||
396 | checkpoints_path = self.output_dir.joinpath("checkpoints") | 416 | checkpoints_path = self.output_dir.joinpath("checkpoints") |
@@ -406,9 +426,6 @@ class Checkpointer: | |||
406 | filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) | 426 | filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) |
407 | torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename)) | 427 | torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename)) |
408 | 428 | ||
409 | del unwrapped | ||
410 | del learned_embeds | ||
411 | |||
412 | @torch.no_grad() | 429 | @torch.no_grad() |
413 | def save_model(self): | 430 | def save_model(self): |
414 | print("Saving model...") | 431 | print("Saving model...") |
@@ -575,7 +592,9 @@ def main(): | |||
575 | # Freeze text_encoder and vae | 592 | # Freeze text_encoder and vae |
576 | freeze_params(vae.parameters()) | 593 | freeze_params(vae.parameters()) |
577 | 594 | ||
578 | if len(args.initializer_token) != 0: | 595 | if len(args.placeholder_token) != 0: |
596 | print(f"Adding text embeddings: {args.placeholder_token}") | ||
597 | |||
579 | # Convert the initializer_token, placeholder_token to ids | 598 | # Convert the initializer_token, placeholder_token to ids |
580 | initializer_token_ids = torch.stack([ | 599 | initializer_token_ids = torch.stack([ |
581 | torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) | 600 | torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) |
@@ -597,14 +616,19 @@ def main(): | |||
597 | 616 | ||
598 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): | 617 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): |
599 | token_embeds[token_id] = embeddings | 618 | token_embeds[token_id] = embeddings |
619 | else: | ||
620 | placeholder_token_id = [] | ||
621 | |||
622 | if args.train_text_encoder: | ||
623 | print(f"Training entire text encoder.") | ||
624 | else: | ||
625 | print(f"Training added text embeddings") | ||
600 | 626 | ||
601 | freeze_params(itertools.chain( | 627 | freeze_params(itertools.chain( |
602 | text_encoder.text_model.encoder.parameters(), | 628 | text_encoder.text_model.encoder.parameters(), |
603 | text_encoder.text_model.final_layer_norm.parameters(), | 629 | text_encoder.text_model.final_layer_norm.parameters(), |
604 | text_encoder.text_model.embeddings.position_embedding.parameters(), | 630 | text_encoder.text_model.embeddings.position_embedding.parameters(), |
605 | )) | 631 | )) |
606 | else: | ||
607 | placeholder_token_id = [] | ||
608 | 632 | ||
609 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | 633 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
610 | 634 | ||
@@ -700,6 +724,7 @@ def main(): | |||
700 | repeats=args.repeats, | 724 | repeats=args.repeats, |
701 | center_crop=args.center_crop, | 725 | center_crop=args.center_crop, |
702 | valid_set_size=args.sample_batch_size*args.sample_batches, | 726 | valid_set_size=args.sample_batch_size*args.sample_batches, |
727 | num_workers=args.dataloader_num_workers, | ||
703 | collate_fn=collate_fn | 728 | collate_fn=collate_fn |
704 | ) | 729 | ) |
705 | 730 | ||
@@ -906,7 +931,7 @@ def main(): | |||
906 | 931 | ||
907 | accelerator.backward(loss) | 932 | accelerator.backward(loss) |
908 | 933 | ||
909 | if args.initializer_token is not None: | 934 | if not args.train_text_encoder: |
910 | # Keep the token embeddings fixed except the newly added | 935 | # Keep the token embeddings fixed except the newly added |
911 | # embeddings for the concept, as we only want to optimize the concept embeddings | 936 | # embeddings for the concept, as we only want to optimize the concept embeddings |
912 | if accelerator.num_processes > 1: | 937 | if accelerator.num_processes > 1: |