diff options
author | Volpeon <git@volpeon.ink> | 2023-01-16 10:51:02 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-16 10:51:02 +0100 |
commit | 9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f (patch) | |
tree | b365457b40a54fed792f2e3e9d776389a5c9017f | |
parent | Handle empty validation dataset (diff) | |
download | textual-inversion-diff-9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f.tar.gz textual-inversion-diff-9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f.tar.bz2 textual-inversion-diff-9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f.zip |
Pad dataset if len(items) < batch_size
-rw-r--r-- | data/csv.py | 3 | ||||
-rw-r--r-- | train_dreambooth.py | 40 |
2 files changed, 23 insertions, 20 deletions
diff --git a/data/csv.py b/data/csv.py index 968af8d..dec66d7 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -267,6 +267,9 @@ class VlpnDataModule(): | |||
267 | items = self.prepare_items(template, expansions, items) | 267 | items = self.prepare_items(template, expansions, items) |
268 | items = self.filter_items(items) | 268 | items = self.filter_items(items) |
269 | 269 | ||
270 | if (len(items) < self.batch_size): | ||
271 | items = (items * self.batch_size)[:self.batch_size] | ||
272 | |||
270 | num_images = len(items) | 273 | num_images = len(items) |
271 | 274 | ||
272 | valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 | 275 | valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 4e41f77..a9fbbbd 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -87,12 +87,6 @@ def parse_args(): | |||
87 | help="Exclude all items with a listed collection.", | 87 | help="Exclude all items with a listed collection.", |
88 | ) | 88 | ) |
89 | parser.add_argument( | 89 | parser.add_argument( |
90 | "--train_text_encoder", | ||
91 | action="store_true", | ||
92 | default=True, | ||
93 | help="Whether to train the whole text encoder." | ||
94 | ) | ||
95 | parser.add_argument( | ||
96 | "--train_text_encoder_epochs", | 90 | "--train_text_encoder_epochs", |
97 | default=999999, | 91 | default=999999, |
98 | help="Number of epochs the text encoder will be trained." | 92 | help="Number of epochs the text encoder will be trained." |
@@ -194,6 +188,11 @@ def parse_args(): | |||
194 | default=100 | 188 | default=100 |
195 | ) | 189 | ) |
196 | parser.add_argument( | 190 | parser.add_argument( |
191 | "--ti_num_train_epochs", | ||
192 | type=int, | ||
193 | default=10 | ||
194 | ) | ||
195 | parser.add_argument( | ||
197 | "--max_train_steps", | 196 | "--max_train_steps", |
198 | type=int, | 197 | type=int, |
199 | default=None, | 198 | default=None, |
@@ -222,6 +221,12 @@ def parse_args(): | |||
222 | help="Initial learning rate (after the potential warmup period) to use.", | 221 | help="Initial learning rate (after the potential warmup period) to use.", |
223 | ) | 222 | ) |
224 | parser.add_argument( | 223 | parser.add_argument( |
224 | "--ti_learning_rate", | ||
225 | type=float, | ||
226 | default=1e-2, | ||
227 | help="Initial learning rate (after the potential warmup period) to use.", | ||
228 | ) | ||
229 | parser.add_argument( | ||
225 | "--scale_lr", | 230 | "--scale_lr", |
226 | action="store_true", | 231 | action="store_true", |
227 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", | 232 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", |
@@ -541,15 +546,14 @@ def main(): | |||
541 | noise_scheduler=noise_scheduler, | 546 | noise_scheduler=noise_scheduler, |
542 | dtype=weight_dtype, | 547 | dtype=weight_dtype, |
543 | seed=args.seed, | 548 | seed=args.seed, |
549 | with_prior_preservation=args.num_class_images != 0, | ||
550 | prior_loss_weight=args.prior_loss_weight, | ||
544 | ) | 551 | ) |
545 | 552 | ||
546 | # Initial TI | 553 | # Initial TI |
547 | 554 | ||
548 | print("Phase 1: Textual Inversion") | 555 | print("Phase 1: Textual Inversion") |
549 | 556 | ||
550 | ti_lr = 1e-1 | ||
551 | ti_train_epochs = 10 | ||
552 | |||
553 | cur_dir = output_dir.joinpath("1-ti") | 557 | cur_dir = output_dir.joinpath("1-ti") |
554 | cur_dir.mkdir(parents=True, exist_ok=True) | 558 | cur_dir.mkdir(parents=True, exist_ok=True) |
555 | 559 | ||
@@ -576,7 +580,7 @@ def main(): | |||
576 | size=args.resolution, | 580 | size=args.resolution, |
577 | shuffle=not args.no_tag_shuffle, | 581 | shuffle=not args.no_tag_shuffle, |
578 | template_key=args.train_data_template, | 582 | template_key=args.train_data_template, |
579 | valid_set_size=args.valid_set_size, | 583 | valid_set_size=1, |
580 | valid_set_repeat=args.valid_set_repeat, | 584 | valid_set_repeat=args.valid_set_repeat, |
581 | seed=args.seed, | 585 | seed=args.seed, |
582 | filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), | 586 | filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), |
@@ -586,7 +590,7 @@ def main(): | |||
586 | 590 | ||
587 | optimizer = optimizer_class( | 591 | optimizer = optimizer_class( |
588 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | 592 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), |
589 | lr=ti_lr, | 593 | lr=args.ti_learning_rate, |
590 | weight_decay=0.0, | 594 | weight_decay=0.0, |
591 | ) | 595 | ) |
592 | 596 | ||
@@ -595,8 +599,8 @@ def main(): | |||
595 | optimizer=optimizer, | 599 | optimizer=optimizer, |
596 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | 600 | num_training_steps_per_epoch=len(datamodule.train_dataloader), |
597 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 601 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
598 | train_epochs=ti_train_epochs, | 602 | train_epochs=args.ti_num_train_epochs, |
599 | warmup_epochs=ti_train_epochs // 4, | 603 | warmup_epochs=args.ti_num_train_epochs // 4, |
600 | ) | 604 | ) |
601 | 605 | ||
602 | trainer( | 606 | trainer( |
@@ -606,18 +610,16 @@ def main(): | |||
606 | val_dataloader=datamodule.val_dataloader, | 610 | val_dataloader=datamodule.val_dataloader, |
607 | optimizer=optimizer, | 611 | optimizer=optimizer, |
608 | lr_scheduler=lr_scheduler, | 612 | lr_scheduler=lr_scheduler, |
609 | num_train_epochs=ti_train_epochs, | 613 | num_train_epochs=args.ti_num_train_epochs, |
610 | sample_frequency=1, | 614 | sample_frequency=2, |
611 | checkpoint_frequency=9999999, | 615 | checkpoint_frequency=9999999, |
612 | with_prior_preservation=args.num_class_images != 0, | ||
613 | prior_loss_weight=args.prior_loss_weight, | ||
614 | # -- | 616 | # -- |
615 | tokenizer=tokenizer, | 617 | tokenizer=tokenizer, |
616 | sample_scheduler=sample_scheduler, | 618 | sample_scheduler=sample_scheduler, |
617 | output_dir=cur_subdir, | 619 | output_dir=cur_subdir, |
618 | placeholder_tokens=[placeholder_token], | 620 | placeholder_tokens=[placeholder_token], |
619 | placeholder_token_ids=placeholder_token_ids, | 621 | placeholder_token_ids=placeholder_token_ids, |
620 | learning_rate=ti_lr, | 622 | learning_rate=args.ti_learning_rate, |
621 | gradient_checkpointing=args.gradient_checkpointing, | 623 | gradient_checkpointing=args.gradient_checkpointing, |
622 | use_emb_decay=True, | 624 | use_emb_decay=True, |
623 | sample_batch_size=args.sample_batch_size, | 625 | sample_batch_size=args.sample_batch_size, |
@@ -699,8 +701,6 @@ def main(): | |||
699 | lr_scheduler=lr_scheduler, | 701 | lr_scheduler=lr_scheduler, |
700 | num_train_epochs=args.num_train_epochs, | 702 | num_train_epochs=args.num_train_epochs, |
701 | sample_frequency=args.sample_frequency, | 703 | sample_frequency=args.sample_frequency, |
702 | with_prior_preservation=args.num_class_images != 0, | ||
703 | prior_loss_weight=args.prior_loss_weight, | ||
704 | # -- | 704 | # -- |
705 | tokenizer=tokenizer, | 705 | tokenizer=tokenizer, |
706 | sample_scheduler=sample_scheduler, | 706 | sample_scheduler=sample_scheduler, |