author    Volpeon <git@volpeon.ink>  2023-01-16 10:51:02 +0100
committer Volpeon <git@volpeon.ink>  2023-01-16 10:51:02 +0100
commit    9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f (patch)
tree      b365457b40a54fed792f2e3e9d776389a5c9017f
parent    Handle empty validation dataset (diff)
download  textual-inversion-diff-9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f.tar.gz
          textual-inversion-diff-9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f.tar.bz2
          textual-inversion-diff-9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f.zip
Pad dataset if len(items) < batch_size
 data/csv.py         |  3 +++
 train_dreambooth.py | 40 ++++++++++++++++++++--------------------
 2 files changed, 23 insertions(+), 20 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 968af8d..dec66d7 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -267,6 +267,9 @@ class VlpnDataModule():
         items = self.prepare_items(template, expansions, items)
         items = self.filter_items(items)
 
+        if (len(items) < self.batch_size):
+            items = (items * self.batch_size)[:self.batch_size]
+
         num_images = len(items)
 
         valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10
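
The guard above tiles the item list and truncates it to exactly batch_size entries, so a dataset smaller than one batch still yields a full batch. A minimal standalone sketch of the idiom (generic names, not this repository's objects):

def pad_items(items: list, batch_size: int) -> list:
    # Illustrative stand-in for the padding added above.
    if len(items) < batch_size:
        # Tile the list, then cut it down to exactly batch_size entries,
        # e.g. ["a", "b"] with batch_size=5 -> ["a", "b", "a", "b", "a"].
        items = (items * batch_size)[:batch_size]
    return items

assert pad_items(["a", "b"], 5) == ["a", "b", "a", "b", "a"]
assert pad_items(["a", "b", "c"], 2) == ["a", "b", "c"]  # already enough
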
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 4e41f77..a9fbbbd 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -87,12 +87,6 @@ def parse_args():
         help="Exclude all items with a listed collection.",
     )
     parser.add_argument(
-        "--train_text_encoder",
-        action="store_true",
-        default=True,
-        help="Whether to train the whole text encoder."
-    )
-    parser.add_argument(
         "--train_text_encoder_epochs",
         default=999999,
         help="Number of epochs the text encoder will be trained."
@@ -194,6 +188,11 @@ def parse_args():
         default=100
     )
     parser.add_argument(
+        "--ti_num_train_epochs",
+        type=int,
+        default=10
+    )
+    parser.add_argument(
         "--max_train_steps",
         type=int,
         default=None,
@@ -222,6 +221,12 @@ def parse_args():
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
+        "--ti_learning_rate",
+        type=float,
+        default=1e-2,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
         "--scale_lr",
         action="store_true",
         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
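
Together with the removal of the hard-coded ti_lr and ti_train_epochs further down, these two hunks move the textual-inversion hyperparameters onto the command line. Note that the new flag's default of 1e-2 differs from the old hard-coded ti_lr = 1e-1. A self-contained sketch of just the two new flags (defaults copied from the hunks, the parser here is illustrative, not the full parse_args()):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--ti_num_train_epochs", type=int, default=10)
parser.add_argument(
    "--ti_learning_rate",
    type=float,
    default=1e-2,
    help="Initial learning rate (after the potential warmup period) to use.",
)

# E.g. restoring the old hard-coded value from the command line:
args = parser.parse_args(["--ti_learning_rate", "1e-1"])
assert args.ti_num_train_epochs == 10 and args.ti_learning_rate == 0.1
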
@@ -541,15 +546,14 @@ def main():
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
         seed=args.seed,
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
     )
 
     # Initial TI
 
     print("Phase 1: Textual Inversion")
 
-    ti_lr = 1e-1
-    ti_train_epochs = 10
-
     cur_dir = output_dir.joinpath("1-ti")
     cur_dir.mkdir(parents=True, exist_ok=True)
 
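
with_prior_preservation and prior_loss_weight move from the two per-phase trainer(...) calls (removed in later hunks) into this shared setup, so both training phases inherit them. A sketch of the pattern, assuming the setup binds common keyword arguments with functools.partial (train_loop and the values below are stand-ins, not this repository's code):

from functools import partial

def train_loop(**kwargs):
    # Stand-in for the real training loop.
    print(sorted(kwargs))

num_class_images, prior_loss_weight = 4, 1.0  # illustrative values

# Bind the phase-independent settings once ...
trainer = partial(
    train_loop,
    with_prior_preservation=num_class_images != 0,
    prior_loss_weight=prior_loss_weight,
)

# ... so each phase call passes only what actually differs:
trainer(num_train_epochs=10, learning_rate=1e-2)
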
@@ -576,7 +580,7 @@ def main():
         size=args.resolution,
         shuffle=not args.no_tag_shuffle,
         template_key=args.train_data_template,
-        valid_set_size=args.valid_set_size,
+        valid_set_size=1,
         valid_set_repeat=args.valid_set_repeat,
         seed=args.seed,
         filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
@@ -586,7 +590,7 @@ def main():
 
     optimizer = optimizer_class(
         text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-        lr=ti_lr,
+        lr=args.ti_learning_rate,
         weight_decay=0.0,
     )
 
@@ -595,8 +599,8 @@ def main():
         optimizer=optimizer,
         num_training_steps_per_epoch=len(datamodule.train_dataloader),
         gradient_accumulation_steps=args.gradient_accumulation_steps,
-        train_epochs=ti_train_epochs,
-        warmup_epochs=ti_train_epochs // 4,
+        train_epochs=args.ti_num_train_epochs,
+        warmup_epochs=args.ti_num_train_epochs // 4,
     )
 
     trainer(
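
The scheduler now derives its warmup from the same flag, so warmup stays pinned to a quarter of the TI phase however many epochs are configured. With the default of 10 that is 2 warmup epochs (floor division):

ti_num_train_epochs = 10                   # the new flag's default
warmup_epochs = ti_num_train_epochs // 4   # integer division, as in the hunk
assert warmup_epochs == 2
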
@@ -606,18 +610,16 @@ def main():
         val_dataloader=datamodule.val_dataloader,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
-        num_train_epochs=ti_train_epochs,
-        sample_frequency=1,
+        num_train_epochs=args.ti_num_train_epochs,
+        sample_frequency=2,
         checkpoint_frequency=9999999,
-        with_prior_preservation=args.num_class_images != 0,
-        prior_loss_weight=args.prior_loss_weight,
         # --
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
         output_dir=cur_subdir,
         placeholder_tokens=[placeholder_token],
         placeholder_token_ids=placeholder_token_ids,
-        learning_rate=ti_lr,
+        learning_rate=args.ti_learning_rate,
         gradient_checkpointing=args.gradient_checkpointing,
         use_emb_decay=True,
         sample_batch_size=args.sample_batch_size,
@@ -699,8 +701,6 @@ def main():
         lr_scheduler=lr_scheduler,
         num_train_epochs=args.num_train_epochs,
         sample_frequency=args.sample_frequency,
-        with_prior_preservation=args.num_class_images != 0,
-        prior_loss_weight=args.prior_loss_weight,
         # --
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,