diff options
author | Volpeon <git@volpeon.ink> | 2022-10-27 17:57:05 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-27 17:57:05 +0200 |
commit | 0bc909409648a3cae0061c3de2b39e486473ae39 (patch) | |
tree | 5fdbcd7c56919293963c3c8b53bdb2099834079d | |
parent | Euler_a: Re-introduce generator arg for reproducible output (diff) | |
download | textual-inversion-diff-0bc909409648a3cae0061c3de2b39e486473ae39.tar.gz textual-inversion-diff-0bc909409648a3cae0061c3de2b39e486473ae39.tar.bz2 textual-inversion-diff-0bc909409648a3cae0061c3de2b39e486473ae39.zip |
Added CLI arg to set dataloader worker num; improved text encoder handling with Dreambooth
-rw-r--r-- | data/csv.py | 10 | ||||
-rw-r--r-- | dreambooth.py | 41 | ||||
-rw-r--r-- | textual_inversion.py | 10 |
3 files changed, 50 insertions, 11 deletions
diff --git a/data/csv.py b/data/csv.py
index f9b5e39..6bd7f9b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -38,7 +38,8 @@ class CSVDataModule(pl.LightningDataModule): | |||
38 | center_crop: bool = False, | 38 | center_crop: bool = False, |
39 | valid_set_size: Optional[int] = None, | 39 | valid_set_size: Optional[int] = None, |
40 | generator: Optional[torch.Generator] = None, | 40 | generator: Optional[torch.Generator] = None, |
41 | collate_fn=None | 41 | collate_fn=None, |
42 | num_workers: int = 0 | ||
42 | ): | 43 | ): |
43 | super().__init__() | 44 | super().__init__() |
44 | 45 | ||
@@ -62,6 +63,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
62 | self.valid_set_size = valid_set_size | 63 | self.valid_set_size = valid_set_size |
63 | self.generator = generator | 64 | self.generator = generator |
64 | self.collate_fn = collate_fn | 65 | self.collate_fn = collate_fn |
66 | self.num_workers = num_workers | ||
65 | self.batch_size = batch_size | 67 | self.batch_size = batch_size |
66 | 68 | ||
67 | def prepare_subdata(self, template, data, num_class_images=1): | 69 | def prepare_subdata(self, template, data, num_class_images=1): |
@@ -113,9 +115,11 @@ class CSVDataModule(pl.LightningDataModule): | |||
113 | size=self.size, interpolation=self.interpolation, | 115 | size=self.size, interpolation=self.interpolation, |
114 | center_crop=self.center_crop) | 116 | center_crop=self.center_crop) |
115 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, | 117 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, |
116 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn) | 118 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn, |
119 | num_workers=self.num_workers) | ||
117 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, | 120 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, |
118 | pin_memory=True, collate_fn=self.collate_fn) | 121 | pin_memory=True, collate_fn=self.collate_fn, |
122 | num_workers=self.num_workers) | ||
119 | 123 | ||
120 | def train_dataloader(self): | 124 | def train_dataloader(self): |
121 | return self.train_dataloader_ | 125 | return self.train_dataloader_ |
diff --git a/dreambooth.py b/dreambooth.py
index db097e5..e71b7f0 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -72,15 +72,23 @@ def parse_args(): | |||
72 | "--placeholder_token", | 72 | "--placeholder_token", |
73 | type=str, | 73 | type=str, |
74 | nargs='*', | 74 | nargs='*', |
75 | default=[], | ||
75 | help="A token to use as a placeholder for the concept.", | 76 | help="A token to use as a placeholder for the concept.", |
76 | ) | 77 | ) |
77 | parser.add_argument( | 78 | parser.add_argument( |
78 | "--initializer_token", | 79 | "--initializer_token", |
79 | type=str, | 80 | type=str, |
80 | nargs='*', | 81 | nargs='*', |
82 | default=[], | ||
81 | help="A token to use as initializer word." | 83 | help="A token to use as initializer word." |
82 | ) | 84 | ) |
83 | parser.add_argument( | 85 | parser.add_argument( |
86 | "--train_text_encoder", | ||
87 | action="store_true", | ||
88 | default=True, | ||
89 | help="Whether to train the whole text encoder." | ||
90 | ) | ||
91 | parser.add_argument( | ||
84 | "--num_class_images", | 92 | "--num_class_images", |
85 | type=int, | 93 | type=int, |
86 | default=400, | 94 | default=400, |
@@ -119,6 +127,15 @@ def parse_args(): | |||
119 | help="Whether to center crop images before resizing to resolution" | 127 | help="Whether to center crop images before resizing to resolution" |
120 | ) | 128 | ) |
121 | parser.add_argument( | 129 | parser.add_argument( |
130 | "--dataloader_num_workers", | ||
131 | type=int, | ||
132 | default=0, | ||
133 | help=( | ||
134 | "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main" | ||
135 | " process." | ||
136 | ), | ||
137 | ) | ||
138 | parser.add_argument( | ||
122 | "--num_train_epochs", | 139 | "--num_train_epochs", |
123 | type=int, | 140 | type=int, |
124 | default=100 | 141 | default=100 |
@@ -323,7 +340,7 @@ def parse_args(): | |||
323 | args.placeholder_token = [args.placeholder_token] | 340 | args.placeholder_token = [args.placeholder_token] |
324 | 341 | ||
325 | if len(args.placeholder_token) == 0: | 342 | if len(args.placeholder_token) == 0: |
326 | args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)] | 343 | args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))] |
327 | 344 | ||
328 | if len(args.placeholder_token) != len(args.initializer_token): | 345 | if len(args.placeholder_token) != len(args.initializer_token): |
329 | raise ValueError("Number of items in --placeholder_token and --initializer_token must match") | 346 | raise ValueError("Number of items in --placeholder_token and --initializer_token must match") |
@@ -391,6 +408,9 @@ class Checkpointer: | |||
391 | 408 | ||
392 | @torch.no_grad() | 409 | @torch.no_grad() |
393 | def save_embedding(self, step, postfix): | 410 | def save_embedding(self, step, postfix): |
411 | if len(self.placeholder_token) == 0: | ||
412 | return | ||
413 | |||
394 | print("Saving checkpoint for step %d..." % step) | 414 | print("Saving checkpoint for step %d..." % step) |
395 | 415 | ||
396 | checkpoints_path = self.output_dir.joinpath("checkpoints") | 416 | checkpoints_path = self.output_dir.joinpath("checkpoints") |
@@ -406,9 +426,6 @@ class Checkpointer: | |||
406 | filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) | 426 | filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) |
407 | torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename)) | 427 | torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename)) |
408 | 428 | ||
409 | del unwrapped | ||
410 | del learned_embeds | ||
411 | |||
412 | @torch.no_grad() | 429 | @torch.no_grad() |
413 | def save_model(self): | 430 | def save_model(self): |
414 | print("Saving model...") | 431 | print("Saving model...") |
@@ -575,7 +592,9 @@ def main(): | |||
575 | # Freeze text_encoder and vae | 592 | # Freeze text_encoder and vae |
576 | freeze_params(vae.parameters()) | 593 | freeze_params(vae.parameters()) |
577 | 594 | ||
578 | if len(args.initializer_token) != 0: | 595 | if len(args.placeholder_token) != 0: |
596 | print(f"Adding text embeddings: {args.placeholder_token}") | ||
597 | |||
579 | # Convert the initializer_token, placeholder_token to ids | 598 | # Convert the initializer_token, placeholder_token to ids |
580 | initializer_token_ids = torch.stack([ | 599 | initializer_token_ids = torch.stack([ |
581 | torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) | 600 | torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) |
@@ -597,14 +616,19 @@ def main(): | |||
597 | 616 | ||
598 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): | 617 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): |
599 | token_embeds[token_id] = embeddings | 618 | token_embeds[token_id] = embeddings |
619 | else: | ||
620 | placeholder_token_id = [] | ||
621 | |||
622 | if args.train_text_encoder: | ||
623 | print(f"Training entire text encoder.") | ||
624 | else: | ||
625 | print(f"Training added text embeddings") | ||
600 | 626 | ||
601 | freeze_params(itertools.chain( | 627 | freeze_params(itertools.chain( |
602 | text_encoder.text_model.encoder.parameters(), | 628 | text_encoder.text_model.encoder.parameters(), |
603 | text_encoder.text_model.final_layer_norm.parameters(), | 629 | text_encoder.text_model.final_layer_norm.parameters(), |
604 | text_encoder.text_model.embeddings.position_embedding.parameters(), | 630 | text_encoder.text_model.embeddings.position_embedding.parameters(), |
605 | )) | 631 | )) |
606 | else: | ||
607 | placeholder_token_id = [] | ||
608 | 632 | ||
609 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | 633 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
610 | 634 | ||
@@ -700,6 +724,7 @@ def main(): | |||
700 | repeats=args.repeats, | 724 | repeats=args.repeats, |
701 | center_crop=args.center_crop, | 725 | center_crop=args.center_crop, |
702 | valid_set_size=args.sample_batch_size*args.sample_batches, | 726 | valid_set_size=args.sample_batch_size*args.sample_batches, |
727 | num_workers=args.dataloader_num_workers, | ||
703 | collate_fn=collate_fn | 728 | collate_fn=collate_fn |
704 | ) | 729 | ) |
705 | 730 | ||
@@ -906,7 +931,7 @@ def main(): | |||
906 | 931 | ||
907 | accelerator.backward(loss) | 932 | accelerator.backward(loss) |
908 | 933 | ||
909 | if args.initializer_token is not None: | 934 | if not args.train_text_encoder: |
910 | # Keep the token embeddings fixed except the newly added | 935 | # Keep the token embeddings fixed except the newly added |
911 | # embeddings for the concept, as we only want to optimize the concept embeddings | 936 | # embeddings for the concept, as we only want to optimize the concept embeddings |
912 | if accelerator.num_processes > 1: | 937 | if accelerator.num_processes > 1: |
diff --git a/textual_inversion.py b/textual_inversion.py
index dd7c3bd..115f3aa 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -117,6 +117,15 @@ def parse_args(): | |||
117 | help="Whether to center crop images before resizing to resolution" | 117 | help="Whether to center crop images before resizing to resolution" |
118 | ) | 118 | ) |
119 | parser.add_argument( | 119 | parser.add_argument( |
120 | "--dataloader_num_workers", | ||
121 | type=int, | ||
122 | default=0, | ||
123 | help=( | ||
124 | "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main" | ||
125 | " process." | ||
126 | ), | ||
127 | ) | ||
128 | parser.add_argument( | ||
120 | "--num_train_epochs", | 129 | "--num_train_epochs", |
121 | type=int, | 130 | type=int, |
122 | default=100 | 131 | default=100 |
@@ -626,6 +635,7 @@ def main(): | |||
626 | repeats=args.repeats, | 635 | repeats=args.repeats, |
627 | center_crop=args.center_crop, | 636 | center_crop=args.center_crop, |
628 | valid_set_size=args.sample_batch_size*args.sample_batches, | 637 | valid_set_size=args.sample_batch_size*args.sample_batches, |
638 | num_workers=args.dataloader_num_workers, | ||
629 | collate_fn=collate_fn | 639 | collate_fn=collate_fn |
630 | ) | 640 | ) |
631 | 641 | ||