author      Volpeon <git@volpeon.ink>  2023-01-16 17:09:01 +0100
committer   Volpeon <git@volpeon.ink>  2023-01-16 17:09:01 +0100
commit      36440e48ce279872d6e736bcb1bf57d13da73a11 (patch)
tree        8ba9593d8a887517c70b01932c137c9c3f759e8f /train_dreambooth.py
parent      More training adjustments (diff)
download    textual-inversion-diff-36440e48ce279872d6e736bcb1bf57d13da73a11.tar.gz
            textual-inversion-diff-36440e48ce279872d6e736bcb1bf57d13da73a11.tar.bz2
            textual-inversion-diff-36440e48ce279872d6e736bcb1bf57d13da73a11.zip
Moved multi-TI code from Dreambooth to TI script
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py | 135
1 file changed, 2 insertions(+), 133 deletions(-)
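In short, this diff strips the two-phase pipeline out of train_dreambooth.py: the --ti_data_template, --ti_num_train_epochs, --ti_batch_size, and --ti_learning_rate flags, their validation, the per-phase seed_generator, and the whole "Phase 1: Textual Inversion" loop are removed, along with the "1-ti"/"2-db" subdirectories, so the remaining single Dreambooth pass writes straight to output_dir. The only additions are a prepare_unet=True argument to the trainer call and the output_dir change that goes with dropping the subdirectories. The per-token TI loop presumably reappears in the TI script named in the commit message; a condensed sketch of how it might look after the move follows the diff.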
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1dc41b1..6511f9b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -200,23 +200,6 @@ def parse_args():
         default=100
     )
     parser.add_argument(
-        "--ti_data_template",
-        type=str,
-        nargs='*',
-        default=[],
-    )
-    parser.add_argument(
-        "--ti_num_train_epochs",
-        type=int,
-        default=10
-    )
-    parser.add_argument(
-        "--ti_batch_size",
-        type=int,
-        default=1,
-        help="Batch size (per device) for the training dataloader."
-    )
-    parser.add_argument(
         "--max_train_steps",
         type=int,
         default=None,
@@ -245,12 +228,6 @@ def parse_args():
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
-        "--ti_learning_rate",
-        type=float,
-        default=1e-2,
-        help="Initial learning rate (after the potential warmup period) to use.",
-    )
-    parser.add_argument(
         "--scale_lr",
         action="store_true",
         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
@@ -482,12 +459,6 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
-    if isinstance(args.ti_data_template, str):
-        args.ti_data_template = [args.ti_data_template]
-
-    if len(args.ti_data_template) == 0:
-        raise ValueError("You must specify --ti_data_template")
-
     if isinstance(args.collection, str):
         args.collection = [args.collection]
 
@@ -521,8 +492,6 @@ def main():
 
     set_seed(args.seed)
 
-    seed_generator = torch.Generator().manual_seed(args.seed)
-
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
@@ -583,107 +552,6 @@ def main():
         prior_loss_weight=args.prior_loss_weight,
     )
 
-    # Initial TI
-
-    print("Phase 1: Textual Inversion")
-
-    cur_dir = output_dir.joinpath("1-ti")
-    cur_dir.mkdir(parents=True, exist_ok=True)
-
-    for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
-        range(len(args.placeholder_tokens)),
-        args.placeholder_tokens,
-        args.initializer_tokens,
-        args.num_vectors,
-        args.ti_data_template
-    ):
-        cur_subdir = cur_dir.joinpath(placeholder_token)
-        cur_subdir.mkdir(parents=True, exist_ok=True)
-
-        placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
-            tokenizer=tokenizer,
-            embeddings=embeddings,
-            placeholder_tokens=[placeholder_token],
-            initializer_tokens=[initializer_token],
-            num_vectors=[num_vectors]
-        )
-
-        print(
-            f"Phase 1.{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})")
-
-        args.seed = seed_generator.seed()
-
-        datamodule = VlpnDataModule(
-            data_file=args.train_data_file,
-            batch_size=args.ti_batch_size,
-            tokenizer=tokenizer,
-            class_subdir=args.class_image_dir,
-            num_class_images=args.num_class_images,
-            size=args.resolution,
-            shuffle=not args.no_tag_shuffle,
-            template_key=data_template,
-            valid_set_size=1,
-            train_set_pad=args.train_set_pad,
-            valid_set_pad=args.valid_set_pad,
-            seed=args.seed,
-            filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
-            dtype=weight_dtype
-        )
-        datamodule.setup()
-
-        optimizer = optimizer_class(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            lr=args.ti_learning_rate,
-            betas=(args.adam_beta1, args.adam_beta2),
-            weight_decay=0.0,
-            eps=args.adam_epsilon,
-        )
-
-        lr_scheduler = get_scheduler(
-            "one_cycle",
-            optimizer=optimizer,
-            num_training_steps_per_epoch=len(datamodule.train_dataloader),
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            train_epochs=args.ti_num_train_epochs,
-        )
-
-        trainer(
-            callbacks_fn=textual_inversion_strategy,
-            project="textual_inversion",
-            train_dataloader=datamodule.train_dataloader,
-            val_dataloader=datamodule.val_dataloader,
-            seed=args.seed,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            num_train_epochs=args.ti_num_train_epochs,
-            sample_frequency=args.ti_num_train_epochs // 5,
-            checkpoint_frequency=9999999,
-            # --
-            tokenizer=tokenizer,
-            sample_scheduler=sample_scheduler,
-            output_dir=cur_subdir,
-            placeholder_tokens=[placeholder_token],
-            placeholder_token_ids=placeholder_token_ids,
-            learning_rate=args.ti_learning_rate,
-            gradient_checkpointing=args.gradient_checkpointing,
-            use_emb_decay=True,
-            sample_batch_size=args.sample_batch_size,
-            sample_num_batches=args.sample_batches,
-            sample_num_steps=args.sample_steps,
-            sample_image_size=args.sample_image_size,
-        )
-
-        embeddings.persist()
-
-    # Dreambooth
-
-    print("Phase 2: Dreambooth")
-
-    cur_dir = output_dir.joinpath("2-db")
-    cur_dir.mkdir(parents=True, exist_ok=True)
-
-    args.seed = seed_generator.seed()
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -746,12 +614,13 @@ def main():
         seed=args.seed,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
+        prepare_unet=True,
         num_train_epochs=args.num_train_epochs,
         sample_frequency=args.sample_frequency,
         # --
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
-        output_dir=cur_dir,
+        output_dir=output_dir,
         train_text_encoder_epochs=args.train_text_encoder_epochs,
         max_grad_norm=args.max_grad_norm,
         use_ema=args.use_ema,
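For reference, here is a condensed sketch of how the removed per-token loop might look once relocated into the TI script. The helper names (add_placeholder_tokens, VlpnDataModule, trainer, textual_inversion_strategy, keyword_filter) come from the removed code above; everything else (argument names without the ti_ prefix, the surrounding setup in the TI script) is an assumption, not the commit's actual contents.

# Hypothetical sketch of the relocated multi-token TI loop; not the commit's
# actual TI-script code. Assumes args, tokenizer, embeddings, text_encoder,
# optimizer_class, trainer, and output_dir are already set up as in the
# removed code above.
from functools import partial

for i, (placeholder_token, initializer_token, num_vectors, data_template) in enumerate(zip(
    args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.ti_data_template
)):
    print(f"Training token {i + 1}/{len(args.placeholder_tokens)}: {placeholder_token}")

    # Register the new placeholder token(s) and initialize their vectors
    # from the initializer token.
    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
        tokenizer=tokenizer,
        embeddings=embeddings,
        placeholder_tokens=[placeholder_token],
        initializer_tokens=[initializer_token],
        num_vectors=[num_vectors],
    )

    # One datamodule per token, filtered to training items that mention it.
    datamodule = VlpnDataModule(
        data_file=args.train_data_file,
        batch_size=args.train_batch_size,
        tokenizer=tokenizer,
        template_key=data_template,
        filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
        seed=args.seed,
    )
    datamodule.setup()

    # Only the temporary token embeddings are optimized during TI; the rest
    # of the text encoder stays frozen.
    optimizer = optimizer_class(
        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
        lr=args.learning_rate,
    )

    trainer(
        callbacks_fn=textual_inversion_strategy,
        project="textual_inversion",
        train_dataloader=datamodule.train_dataloader,
        val_dataloader=datamodule.val_dataloader,
        optimizer=optimizer,
        num_train_epochs=args.num_train_epochs,
        placeholder_tokens=[placeholder_token],
        placeholder_token_ids=placeholder_token_ids,
        output_dir=output_dir.joinpath(placeholder_token),
    )

    # Freeze the trained vectors into the persistent embedding matrix
    # before moving on to the next token.
    embeddings.persist()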