author     Volpeon <git@volpeon.ink>  2023-01-16 15:52:43 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-16 15:52:43 +0100
commit     6c8cffe28baeafac77d047ff3f8ded9418033e2f (patch)
tree       807c527deb1b15ef795f5cd8a7682151c69a037e /train_dreambooth.py
parent     Pad dataset if len(items) < batch_size (diff)
More training adjustments
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py | 71
1 file changed, 59 insertions(+), 12 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index a9fbbbd..1dc41b1 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -55,6 +55,18 @@ def parse_args():
         default="template",
     )
     parser.add_argument(
+        "--train_set_pad",
+        type=int,
+        default=None,
+        help="The number to fill train dataset items up to."
+    )
+    parser.add_argument(
+        "--valid_set_pad",
+        type=int,
+        default=None,
+        help="The number to fill validation dataset items up to."
+    )
+    parser.add_argument(
         "--project",
         type=str,
         default=None,
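
Since the two pad options are plain argparse integers, their behavior is easy to pin down in isolation; a minimal sketch (the explicit argv lists are only for demonstration):

    import argparse

    # A minimal sketch of the two new options, mirroring the diff above.
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_set_pad", type=int, default=None,
                        help="The number to fill train dataset items up to.")
    parser.add_argument("--valid_set_pad", type=int, default=None,
                        help="The number to fill validation dataset items up to.")

    args = parser.parse_args(["--train_set_pad", "4", "--valid_set_pad", "2"])
    assert args.train_set_pad == 4 and args.valid_set_pad == 2
    assert parser.parse_args([]).train_set_pad is None   # both default to None
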
@@ -188,11 +200,23 @@ def parse_args():
         default=100
     )
     parser.add_argument(
+        "--ti_data_template",
+        type=str,
+        nargs='*',
+        default=[],
+    )
+    parser.add_argument(
         "--ti_num_train_epochs",
         type=int,
         default=10
     )
     parser.add_argument(
+        "--ti_batch_size",
+        type=int,
+        default=1,
+        help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
         "--max_train_steps",
         type=int,
         default=None,
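
The nargs='*' on --ti_data_template means the flag collects zero or more values into a list, which is what the validation added further down relies on; a small standalone sketch:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--ti_data_template", type=str, nargs='*', default=[])

    # Zero values -> the empty-list default; several values -> a list of them.
    assert parser.parse_args([]).ti_data_template == []
    assert parser.parse_args(
        ["--ti_data_template", "style", "subject"]
    ).ti_data_template == ["style", "subject"]
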
@@ -458,6 +482,12 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")

+    if isinstance(args.ti_data_template, str):
+        args.ti_data_template = [args.ti_data_template]
+
+    if len(args.ti_data_template) == 0:
+        raise ValueError("You must specify --ti_data_template")
+
     if isinstance(args.collection, str):
         args.collection = [args.collection]

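
A minimal sketch of the normalization this guard performs. Note that the isinstance branch can only fire when args arrive from somewhere other than the CLI (for instance a saved config), since nargs='*' always yields a list; normalize_ti_data_template is a hypothetical stand-in name:

    def normalize_ti_data_template(value):
        # Same normalization as the guard above: wrap a bare string, then
        # insist on at least one template.
        if isinstance(value, str):
            value = [value]
        if len(value) == 0:
            raise ValueError("You must specify --ti_data_template")
        return value

    assert normalize_ti_data_template("style") == ["style"]
    assert normalize_ti_data_template(["a", "b"]) == ["a", "b"]
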
@@ -491,6 +521,8 @@ def main():

     set_seed(args.seed)

+    seed_generator = torch.Generator().manual_seed(args.seed)
+
     save_args(output_dir, args)

     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
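
A quick sketch of the torch.Generator API used here. One caveat worth knowing: per PyTorch's documentation, Generator.seed() re-seeds from system entropy (std::random_device or the current time) rather than deriving a value from the earlier manual_seed, so the per-phase seeds drawn later are not a deterministic function of args.seed:

    import torch

    g = torch.Generator().manual_seed(1234)
    assert g.initial_seed() == 1234   # manual_seed() pins the state

    s1 = g.seed()                     # draws a fresh entropy-based seed ...
    s2 = g.seed()                     # ... so repeated calls differ (almost surely)
    assert g.initial_seed() == s2     # the generator keeps the last drawn seed
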
@@ -512,6 +544,8 @@ def main():
     if not embeddings_dir.exists() or not embeddings_dir.is_dir():
         raise ValueError("--embeddings_dir must point to an existing directory")

+    embeddings.persist()
+
     added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
     print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

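
The managed-embeddings class is not part of this diff, so the following is only a hypothetical sketch of what a persist() call suggests, inferred from the temp_token_embedding attribute used later in the file: trained rows are copied from a temporary table into the permanent one so later phases see them as fixed tokens. All names and internals here are assumptions:

    import torch
    import torch.nn as nn

    class ManagedEmbeddings(nn.Module):
        """Hypothetical sketch, not the repository's actual class."""

        def __init__(self, num_tokens: int, dim: int):
            super().__init__()
            self.token_embedding = nn.Embedding(num_tokens, dim)       # permanent
            self.temp_token_embedding = nn.Embedding(num_tokens, dim)  # trainable
            self.temp_ids: list = []   # rows currently under training

        def persist(self):
            # Copy trained rows into the permanent table; nothing is
            # considered temporary afterwards.
            with torch.no_grad():
                for i in self.temp_ids:
                    self.token_embedding.weight[i] = self.temp_token_embedding.weight[i]
            self.temp_ids = []
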
@@ -545,7 +579,6 @@ def main():
         vae=vae,
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
-        seed=args.seed,
         with_prior_preservation=args.num_class_images != 0,
         prior_loss_weight=args.prior_loss_weight,
     )
@@ -557,13 +590,17 @@ def main():
     cur_dir = output_dir.joinpath("1-ti")
     cur_dir.mkdir(parents=True, exist_ok=True)

-    for placeholder_token, initializer_token, num_vectors in zip(args.placeholder_tokens, args.initializer_tokens, args.num_vectors):
-        print(f"Phase 1.1: {placeholder_token} ({num_vectors}) ({initializer_token})")
-
+    for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
+        range(len(args.placeholder_tokens)),
+        args.placeholder_tokens,
+        args.initializer_tokens,
+        args.num_vectors,
+        args.ti_data_template
+    ):
         cur_subdir = cur_dir.joinpath(placeholder_token)
         cur_subdir.mkdir(parents=True, exist_ok=True)

-        placeholder_token_ids, _ = add_placeholder_tokens(
+        placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
             placeholder_tokens=[placeholder_token],
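
The rewritten loop zips a range in to get a phase counter; enumerate over the zip is an equivalent spelling, shown with toy data below. Note that zip truncates to its shortest input, so if --ti_data_template supplies fewer entries than --placeholder_tokens, the surplus tokens are silently skipped:

    tokens = ["<tok-a>", "<tok-b>"]
    inits = ["person", "dog"]
    vectors = [4, 2]
    templates = ["style", "subject"]

    # Equivalent to zipping range(len(tokens)) in as a fifth iterable.
    for i, (tok, init, nv, tmpl) in enumerate(zip(tokens, inits, vectors, templates)):
        print(f"Phase 1.{i + 1}: {tok} ({nv}) ({init}) [{tmpl}]")
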
@@ -571,17 +608,23 @@ def main():
             num_vectors=[num_vectors]
         )

+        print(
+            f"Phase 1.{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})")
+
+        args.seed = seed_generator.seed()
+
         datamodule = VlpnDataModule(
             data_file=args.train_data_file,
-            batch_size=args.train_batch_size,
+            batch_size=args.ti_batch_size,
             tokenizer=tokenizer,
             class_subdir=args.class_image_dir,
             num_class_images=args.num_class_images,
             size=args.resolution,
             shuffle=not args.no_tag_shuffle,
-            template_key=args.train_data_template,
+            template_key=data_template,
             valid_set_size=1,
-            valid_set_repeat=args.valid_set_repeat,
+            train_set_pad=args.train_set_pad,
+            valid_set_pad=args.valid_set_pad,
             seed=args.seed,
             filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
             dtype=weight_dtype
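
VlpnDataModule's padding implementation is not shown in this diff (the parent commit only says "Pad dataset if len(items) < batch_size"), so here is a hypothetical sketch of what a pad target plausibly does; pad_items is an assumed name:

    def pad_items(items: list, target) -> list:
        # Repeat the item list until it reaches the requested size, so a
        # tiny dataset can still fill a whole batch. No-op when target is
        # None or already met.
        if target is None or len(items) == 0 or len(items) >= target:
            return items
        repeats = -(-target // len(items))      # ceiling division
        return (items * repeats)[:target]

    assert pad_items([1, 2, 3], 8) == [1, 2, 3, 1, 2, 3, 1, 2]
    assert pad_items([1, 2, 3], None) == [1, 2, 3]
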
@@ -591,7 +634,9 @@ def main():
         optimizer = optimizer_class(
             text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
             lr=args.ti_learning_rate,
+            betas=(args.adam_beta1, args.adam_beta2),
             weight_decay=0.0,
+            eps=args.adam_epsilon,
         )

         lr_scheduler = get_scheduler(
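
With betas and eps now forwarded, the TI-phase optimizer call lines up with a standard Adam/AdamW construction. A minimal equivalent, assuming optimizer_class resolves to torch.optim.AdamW and using typical values for the adam_* args (the real values come from argparse):

    import torch

    # Stand-in parameter for the temp token embedding's parameters().
    params = [torch.nn.Parameter(torch.zeros(4, 768))]

    optimizer = torch.optim.AdamW(
        params,
        lr=1e-2,              # args.ti_learning_rate
        betas=(0.9, 0.999),   # (args.adam_beta1, args.adam_beta2)
        weight_decay=0.0,     # the embedding is left undecayed
        eps=1e-8,             # args.adam_epsilon
    )
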
@@ -600,7 +645,6 @@ def main():
             num_training_steps_per_epoch=len(datamodule.train_dataloader),
             gradient_accumulation_steps=args.gradient_accumulation_steps,
             train_epochs=args.ti_num_train_epochs,
-            warmup_epochs=args.ti_num_train_epochs // 4,
         )

         trainer(
606 trainer( 650 trainer(
@@ -608,10 +652,11 @@ def main():
608 project="textual_inversion", 652 project="textual_inversion",
609 train_dataloader=datamodule.train_dataloader, 653 train_dataloader=datamodule.train_dataloader,
610 val_dataloader=datamodule.val_dataloader, 654 val_dataloader=datamodule.val_dataloader,
655 seed=args.seed,
611 optimizer=optimizer, 656 optimizer=optimizer,
612 lr_scheduler=lr_scheduler, 657 lr_scheduler=lr_scheduler,
613 num_train_epochs=args.ti_num_train_epochs, 658 num_train_epochs=args.ti_num_train_epochs,
614 sample_frequency=2, 659 sample_frequency=args.ti_num_train_epochs // 5,
615 checkpoint_frequency=9999999, 660 checkpoint_frequency=9999999,
616 # -- 661 # --
617 tokenizer=tokenizer, 662 tokenizer=tokenizer,
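
With the default --ti_num_train_epochs of 10 added earlier in this file, the new sample_frequency expression reproduces the old hardcoded cadence, and scales with longer runs:

    ti_num_train_epochs = 10                 # the default from parse_args()
    assert ti_num_train_epochs // 5 == 2     # same cadence as the old sample_frequency=2
    assert 25 // 5 == 5                      # a 25-epoch run samples every 5 epochs
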
@@ -637,7 +682,7 @@ def main():
     cur_dir = output_dir.joinpath("2-db")
     cur_dir.mkdir(parents=True, exist_ok=True)

-    args.seed = (args.seed + 28635) >> 32
+    args.seed = seed_generator.seed()

     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
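
The replaced expression had a quirk that a worked example makes obvious: any seed below 2**32 - 28635 is shifted entirely out of the integer, so the derived phase-2 seed was always 0. The new seed_generator.seed() call draws a fresh value instead:

    seed = 1337
    assert (seed + 28635) >> 32 == 0          # true for every seed < 2**32 - 28635
    assert ((2**32) + 5 + 28635) >> 32 == 1   # only seeds past 32 bits survive the shift
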
@@ -654,7 +699,8 @@ def main():
         shuffle=not args.no_tag_shuffle,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
-        valid_set_repeat=args.valid_set_repeat,
+        train_set_pad=args.train_set_pad,
+        valid_set_pad=args.valid_set_pad,
         seed=args.seed,
         filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
         dtype=weight_dtype
@@ -697,6 +743,7 @@ def main():
         project="dreambooth",
         train_dataloader=datamodule.train_dataloader,
         val_dataloader=datamodule.val_dataloader,
+        seed=args.seed,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         num_train_epochs=args.num_train_epochs,