diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 71 |
1 files changed, 59 insertions, 12 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index a9fbbbd..1dc41b1 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -55,6 +55,18 @@ def parse_args(): | |||
55 | default="template", | 55 | default="template", |
56 | ) | 56 | ) |
57 | parser.add_argument( | 57 | parser.add_argument( |
58 | "--train_set_pad", | ||
59 | type=int, | ||
60 | default=None, | ||
61 | help="The number to fill train dataset items up to." | ||
62 | ) | ||
63 | parser.add_argument( | ||
64 | "--valid_set_pad", | ||
65 | type=int, | ||
66 | default=None, | ||
67 | help="The number to fill validation dataset items up to." | ||
68 | ) | ||
69 | parser.add_argument( | ||
58 | "--project", | 70 | "--project", |
59 | type=str, | 71 | type=str, |
60 | default=None, | 72 | default=None, |
@@ -188,11 +200,23 @@ def parse_args(): | |||
188 | default=100 | 200 | default=100 |
189 | ) | 201 | ) |
190 | parser.add_argument( | 202 | parser.add_argument( |
203 | "--ti_data_template", | ||
204 | type=str, | ||
205 | nargs='*', | ||
206 | default=[], | ||
207 | ) | ||
208 | parser.add_argument( | ||
191 | "--ti_num_train_epochs", | 209 | "--ti_num_train_epochs", |
192 | type=int, | 210 | type=int, |
193 | default=10 | 211 | default=10 |
194 | ) | 212 | ) |
195 | parser.add_argument( | 213 | parser.add_argument( |
214 | "--ti_batch_size", | ||
215 | type=int, | ||
216 | default=1, | ||
217 | help="Batch size (per device) for the training dataloader." | ||
218 | ) | ||
219 | parser.add_argument( | ||
196 | "--max_train_steps", | 220 | "--max_train_steps", |
197 | type=int, | 221 | type=int, |
198 | default=None, | 222 | default=None, |
@@ -458,6 +482,12 @@ def parse_args(): | |||
458 | if len(args.placeholder_tokens) != len(args.num_vectors): | 482 | if len(args.placeholder_tokens) != len(args.num_vectors): |
459 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 483 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") |
460 | 484 | ||
485 | if isinstance(args.ti_data_template, str): | ||
486 | args.ti_data_template = [args.ti_data_template] | ||
487 | |||
488 | if len(args.ti_data_template) == 0: | ||
489 | raise ValueError("You must specify --ti_data_template") | ||
490 | |||
461 | if isinstance(args.collection, str): | 491 | if isinstance(args.collection, str): |
462 | args.collection = [args.collection] | 492 | args.collection = [args.collection] |
463 | 493 | ||
@@ -491,6 +521,8 @@ def main(): | |||
491 | 521 | ||
492 | set_seed(args.seed) | 522 | set_seed(args.seed) |
493 | 523 | ||
524 | seed_generator = torch.Generator().manual_seed(args.seed) | ||
525 | |||
494 | save_args(output_dir, args) | 526 | save_args(output_dir, args) |
495 | 527 | ||
496 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 528 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
@@ -512,6 +544,8 @@ def main(): | |||
512 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 544 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
513 | raise ValueError("--embeddings_dir must point to an existing directory") | 545 | raise ValueError("--embeddings_dir must point to an existing directory") |
514 | 546 | ||
547 | embeddings.persist() | ||
548 | |||
515 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 549 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
516 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 550 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
517 | 551 | ||
@@ -545,7 +579,6 @@ def main(): | |||
545 | vae=vae, | 579 | vae=vae, |
546 | noise_scheduler=noise_scheduler, | 580 | noise_scheduler=noise_scheduler, |
547 | dtype=weight_dtype, | 581 | dtype=weight_dtype, |
548 | seed=args.seed, | ||
549 | with_prior_preservation=args.num_class_images != 0, | 582 | with_prior_preservation=args.num_class_images != 0, |
550 | prior_loss_weight=args.prior_loss_weight, | 583 | prior_loss_weight=args.prior_loss_weight, |
551 | ) | 584 | ) |
@@ -557,13 +590,17 @@ def main(): | |||
557 | cur_dir = output_dir.joinpath("1-ti") | 590 | cur_dir = output_dir.joinpath("1-ti") |
558 | cur_dir.mkdir(parents=True, exist_ok=True) | 591 | cur_dir.mkdir(parents=True, exist_ok=True) |
559 | 592 | ||
560 | for placeholder_token, initializer_token, num_vectors in zip(args.placeholder_tokens, args.initializer_tokens, args.num_vectors): | 593 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( |
561 | print(f"Phase 1.1: {placeholder_token} ({num_vectors}) ({initializer_token})") | 594 | range(len(args.placeholder_tokens)), |
562 | 595 | args.placeholder_tokens, | |
596 | args.initializer_tokens, | ||
597 | args.num_vectors, | ||
598 | args.ti_data_template | ||
599 | ): | ||
563 | cur_subdir = cur_dir.joinpath(placeholder_token) | 600 | cur_subdir = cur_dir.joinpath(placeholder_token) |
564 | cur_subdir.mkdir(parents=True, exist_ok=True) | 601 | cur_subdir.mkdir(parents=True, exist_ok=True) |
565 | 602 | ||
566 | placeholder_token_ids, _ = add_placeholder_tokens( | 603 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
567 | tokenizer=tokenizer, | 604 | tokenizer=tokenizer, |
568 | embeddings=embeddings, | 605 | embeddings=embeddings, |
569 | placeholder_tokens=[placeholder_token], | 606 | placeholder_tokens=[placeholder_token], |
@@ -571,17 +608,23 @@ def main(): | |||
571 | num_vectors=[num_vectors] | 608 | num_vectors=[num_vectors] |
572 | ) | 609 | ) |
573 | 610 | ||
611 | print( | ||
612 | f"Phase 1.{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})") | ||
613 | |||
614 | args.seed = seed_generator.seed() | ||
615 | |||
574 | datamodule = VlpnDataModule( | 616 | datamodule = VlpnDataModule( |
575 | data_file=args.train_data_file, | 617 | data_file=args.train_data_file, |
576 | batch_size=args.train_batch_size, | 618 | batch_size=args.ti_batch_size, |
577 | tokenizer=tokenizer, | 619 | tokenizer=tokenizer, |
578 | class_subdir=args.class_image_dir, | 620 | class_subdir=args.class_image_dir, |
579 | num_class_images=args.num_class_images, | 621 | num_class_images=args.num_class_images, |
580 | size=args.resolution, | 622 | size=args.resolution, |
581 | shuffle=not args.no_tag_shuffle, | 623 | shuffle=not args.no_tag_shuffle, |
582 | template_key=args.train_data_template, | 624 | template_key=data_template, |
583 | valid_set_size=1, | 625 | valid_set_size=1, |
584 | valid_set_repeat=args.valid_set_repeat, | 626 | train_set_pad=args.train_set_pad, |
627 | valid_set_pad=args.valid_set_pad, | ||
585 | seed=args.seed, | 628 | seed=args.seed, |
586 | filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), | 629 | filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), |
587 | dtype=weight_dtype | 630 | dtype=weight_dtype |
@@ -591,7 +634,9 @@ def main(): | |||
591 | optimizer = optimizer_class( | 634 | optimizer = optimizer_class( |
592 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | 635 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), |
593 | lr=args.ti_learning_rate, | 636 | lr=args.ti_learning_rate, |
637 | betas=(args.adam_beta1, args.adam_beta2), | ||
594 | weight_decay=0.0, | 638 | weight_decay=0.0, |
639 | eps=args.adam_epsilon, | ||
595 | ) | 640 | ) |
596 | 641 | ||
597 | lr_scheduler = get_scheduler( | 642 | lr_scheduler = get_scheduler( |
@@ -600,7 +645,6 @@ def main(): | |||
600 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | 645 | num_training_steps_per_epoch=len(datamodule.train_dataloader), |
601 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 646 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
602 | train_epochs=args.ti_num_train_epochs, | 647 | train_epochs=args.ti_num_train_epochs, |
603 | warmup_epochs=args.ti_num_train_epochs // 4, | ||
604 | ) | 648 | ) |
605 | 649 | ||
606 | trainer( | 650 | trainer( |
@@ -608,10 +652,11 @@ def main(): | |||
608 | project="textual_inversion", | 652 | project="textual_inversion", |
609 | train_dataloader=datamodule.train_dataloader, | 653 | train_dataloader=datamodule.train_dataloader, |
610 | val_dataloader=datamodule.val_dataloader, | 654 | val_dataloader=datamodule.val_dataloader, |
655 | seed=args.seed, | ||
611 | optimizer=optimizer, | 656 | optimizer=optimizer, |
612 | lr_scheduler=lr_scheduler, | 657 | lr_scheduler=lr_scheduler, |
613 | num_train_epochs=args.ti_num_train_epochs, | 658 | num_train_epochs=args.ti_num_train_epochs, |
614 | sample_frequency=2, | 659 | sample_frequency=args.ti_num_train_epochs // 5, |
615 | checkpoint_frequency=9999999, | 660 | checkpoint_frequency=9999999, |
616 | # -- | 661 | # -- |
617 | tokenizer=tokenizer, | 662 | tokenizer=tokenizer, |
@@ -637,7 +682,7 @@ def main(): | |||
637 | cur_dir = output_dir.joinpath("2-db") | 682 | cur_dir = output_dir.joinpath("2-db") |
638 | cur_dir.mkdir(parents=True, exist_ok=True) | 683 | cur_dir.mkdir(parents=True, exist_ok=True) |
639 | 684 | ||
640 | args.seed = (args.seed + 28635) >> 32 | 685 | args.seed = seed_generator.seed() |
641 | 686 | ||
642 | datamodule = VlpnDataModule( | 687 | datamodule = VlpnDataModule( |
643 | data_file=args.train_data_file, | 688 | data_file=args.train_data_file, |
@@ -654,7 +699,8 @@ def main(): | |||
654 | shuffle=not args.no_tag_shuffle, | 699 | shuffle=not args.no_tag_shuffle, |
655 | template_key=args.train_data_template, | 700 | template_key=args.train_data_template, |
656 | valid_set_size=args.valid_set_size, | 701 | valid_set_size=args.valid_set_size, |
657 | valid_set_repeat=args.valid_set_repeat, | 702 | train_set_pad=args.train_set_pad, |
703 | valid_set_pad=args.valid_set_pad, | ||
658 | seed=args.seed, | 704 | seed=args.seed, |
659 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), | 705 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), |
660 | dtype=weight_dtype | 706 | dtype=weight_dtype |
@@ -697,6 +743,7 @@ def main(): | |||
697 | project="dreambooth", | 743 | project="dreambooth", |
698 | train_dataloader=datamodule.train_dataloader, | 744 | train_dataloader=datamodule.train_dataloader, |
699 | val_dataloader=datamodule.val_dataloader, | 745 | val_dataloader=datamodule.val_dataloader, |
746 | seed=args.seed, | ||
700 | optimizer=optimizer, | 747 | optimizer=optimizer, |
701 | lr_scheduler=lr_scheduler, | 748 | lr_scheduler=lr_scheduler, |
702 | num_train_epochs=args.num_train_epochs, | 749 | num_train_epochs=args.num_train_epochs, |