 data/csv.py          | 10 +++++++---
 dreambooth.py        | 41 +++++++++++++++++++++++++++++++++--------
 textual_inversion.py | 10 ++++++++++
 3 files changed, 50 insertions(+), 11 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index f9b5e39..6bd7f9b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -38,7 +38,8 @@ class CSVDataModule(pl.LightningDataModule):
         center_crop: bool = False,
         valid_set_size: Optional[int] = None,
         generator: Optional[torch.Generator] = None,
-        collate_fn=None
+        collate_fn=None,
+        num_workers: int = 0
     ):
         super().__init__()
 
@@ -62,6 +63,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.valid_set_size = valid_set_size
         self.generator = generator
         self.collate_fn = collate_fn
+        self.num_workers = num_workers
         self.batch_size = batch_size
 
     def prepare_subdata(self, template, data, num_class_images=1):
@@ -113,9 +115,11 @@ class CSVDataModule(pl.LightningDataModule):
                                       size=self.size, interpolation=self.interpolation,
                                       center_crop=self.center_crop)
         self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
-                                            shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
+                                            shuffle=True, pin_memory=True, collate_fn=self.collate_fn,
+                                            num_workers=self.num_workers)
         self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size,
-                                          pin_memory=True, collate_fn=self.collate_fn)
+                                          pin_memory=True, collate_fn=self.collate_fn,
+                                          num_workers=self.num_workers)
 
     def train_dataloader(self):
         return self.train_dataloader_
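Note on this hunk: the change simply threads one new constructor argument through CSVDataModule into both torch.utils.data.DataLoader calls. A minimal, self-contained sketch of the resulting wiring (the standalone names below are illustrative, not the module's exact attributes):

import torch
from torch.utils.data import DataLoader, TensorDataset

def make_loaders(train_dataset, val_dataset, batch_size, collate_fn=None, num_workers=0):
    # num_workers=0 loads batches in the main process (PyTorch's default);
    # values > 0 spawn that many worker subprocesses to prefetch batches.
    train = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                       pin_memory=True, collate_fn=collate_fn, num_workers=num_workers)
    val = DataLoader(val_dataset, batch_size=batch_size,
                     pin_memory=True, collate_fn=collate_fn, num_workers=num_workers)
    return train, val

train_loader, val_loader = make_loaders(
    TensorDataset(torch.zeros(8, 3)), TensorDataset(torch.zeros(4, 3)),
    batch_size=2, num_workers=2)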
diff --git a/dreambooth.py b/dreambooth.py
index db097e5..e71b7f0 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -72,15 +72,23 @@ def parse_args():
         "--placeholder_token",
         type=str,
         nargs='*',
+        default=[],
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
         "--initializer_token",
         type=str,
         nargs='*',
+        default=[],
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--train_text_encoder",
+        action="store_true",
+        default=True,
+        help="Whether to train the whole text encoder."
+    )
+    parser.add_argument(
         "--num_class_images",
         type=int,
         default=400,
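One observation on the new --train_text_encoder flag: combining action="store_true" with default=True means the option can never be switched off from the command line, since omitting the flag still yields True. If a togglable default-on boolean is intended, the usual stdlib pattern adds an opt-out flag (the companion flag below is hypothetical, not part of this commit):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--train_text_encoder", action="store_true", default=True)
# Hypothetical opt-out companion flag:
parser.add_argument("--no_train_text_encoder", dest="train_text_encoder",
                    action="store_false")

print(parser.parse_args([]).train_text_encoder)                           # True
print(parser.parse_args(["--no_train_text_encoder"]).train_text_encoder)  # False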
@@ -119,6 +127,15 @@ def parse_args():
         help="Whether to center crop images before resizing to resolution"
     )
     parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
+            " process."
+        ),
+    )
+    parser.add_argument(
         "--num_train_epochs",
         type=int,
         default=100
@@ -323,7 +340,7 @@ def parse_args():
         args.placeholder_token = [args.placeholder_token]
 
     if len(args.placeholder_token) == 0:
-        args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
+        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
 
     if len(args.placeholder_token) != len(args.initializer_token):
         raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
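This hunk fixes a genuine bug: range() was previously called on the initializer token list itself, which raises TypeError for any non-integer argument. With the fix, one auto-named placeholder is generated per initializer token:

initializer_token = ["cat", "dog"]  # example values; any token list works
placeholder_token = [f"<*{i}>" for i in range(len(initializer_token))]
print(placeholder_token)  # ['<*0>', '<*1>']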
@@ -391,6 +408,9 @@ class Checkpointer:
 
     @torch.no_grad()
     def save_embedding(self, step, postfix):
+        if len(self.placeholder_token) == 0:
+            return
+
         print("Saving checkpoint for step %d..." % step)
 
         checkpoints_path = self.output_dir.joinpath("checkpoints")
@@ -406,9 +426,6 @@ class Checkpointer:
             filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
             torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
-        del unwrapped
-        del learned_embeds
-
     @torch.no_grad()
     def save_model(self):
         print("Saving model...")
@@ -575,7 +592,9 @@ def main():
     # Freeze text_encoder and vae
     freeze_params(vae.parameters())
 
-    if len(args.initializer_token) != 0:
+    if len(args.placeholder_token) != 0:
+        print(f"Adding text embeddings: {args.placeholder_token}")
+
         # Convert the initializer_token, placeholder_token to ids
         initializer_token_ids = torch.stack([
             torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
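For readers skimming the context lines: the [:1] slice keeps only the first sub-token of each initializer word, so every placeholder maps to exactly one embedding row even when the tokenizer splits a word. A toy illustration (the encode results below are made up; real ids depend on the actual tokenizer):

import torch

# Stand-in for tokenizer.encode(token, add_special_tokens=False).
encoded = {"cat": [2368], "mountain lake": [5107, 11003]}

initializer_token_ids = torch.stack([
    torch.tensor(encoded[token][:1])  # keep only the first sub-token id
    for token in ["cat", "mountain lake"]
])
print(initializer_token_ids.shape)  # torch.Size([2, 1])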
@@ -597,14 +616,19 @@ def main():
 
         for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
             token_embeds[token_id] = embeddings
+    else:
+        placeholder_token_id = []
+
+    if args.train_text_encoder:
+        print(f"Training entire text encoder.")
+    else:
+        print(f"Training added text embeddings")
 
         freeze_params(itertools.chain(
             text_encoder.text_model.encoder.parameters(),
             text_encoder.text_model.final_layer_norm.parameters(),
             text_encoder.text_model.embeddings.position_embedding.parameters(),
         ))
-    else:
-        placeholder_token_id = []
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
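After this restructuring, the freeze block sits in the else branch and runs only when --train_text_encoder is off: everything in the text model except the token embedding table stays frozen, so only the newly added embeddings receive updates. freeze_params is presumably the usual requires_grad helper; a minimal sketch under that assumption:

def freeze_params(params):
    # Detach parameters from optimization; autograd will not track them.
    for p in params:
        p.requires_grad = False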
@@ -700,6 +724,7 @@ def main():
         repeats=args.repeats,
         center_crop=args.center_crop,
         valid_set_size=args.sample_batch_size*args.sample_batches,
+        num_workers=args.dataloader_num_workers,
         collate_fn=collate_fn
     )
 
@@ -906,7 +931,7 @@ def main():
 
                 accelerator.backward(loss)
 
-                if args.initializer_token is not None:
+                if not args.train_text_encoder:
                     # Keep the token embeddings fixed except the newly added
                     # embeddings for the concept, as we only want to optimize the concept embeddings
                     if accelerator.num_processes > 1:
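The guard's meaning also changes here: it previously keyed off args.initializer_token (which, with the new default=[], is never None), and now keys off the text-encoder training mode. The block it protects typically zeroes every embedding gradient row except the added placeholders; a sketch of that pattern, not the file's exact code:

import torch

def mask_embedding_grads(embedding: torch.nn.Embedding, placeholder_token_id):
    # Zero gradients for all rows except the newly added placeholder rows,
    # so only the concept embeddings are optimized.
    index_no_updates = torch.ones(embedding.weight.grad.shape[0], dtype=torch.bool)
    index_no_updates[placeholder_token_id] = False
    embedding.weight.grad.data[index_no_updates, :] = 0.0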
diff --git a/textual_inversion.py b/textual_inversion.py
index dd7c3bd..115f3aa 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -117,6 +117,15 @@ def parse_args():
         help="Whether to center crop images before resizing to resolution"
     )
     parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
+            " process."
+        ),
+    )
+    parser.add_argument(
         "--num_train_epochs",
         type=int,
         default=100
@@ -626,6 +635,7 @@ def main():
         repeats=args.repeats,
         center_crop=args.center_crop,
         valid_set_size=args.sample_batch_size*args.sample_batches,
+        num_workers=args.dataloader_num_workers,
         collate_fn=collate_fn
     )
 
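In summary, both training scripts now expose the same --dataloader_num_workers flag and thread it through CSVDataModule into their DataLoaders; the default of 0 preserves the previous single-process loading behavior.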
