diff options
author | Volpeon <git@volpeon.ink> | 2023-01-15 12:33:52 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-15 12:33:52 +0100 |
commit | 59bf501198d7ff6c0c03c45e92adef14069d5ac6 (patch) | |
tree | aae4c7204b4f04bf2146408fb88892071840a05d /train_ti.py | |
parent | Removed unused code, put training callbacks in dataclass (diff) | |
download | textual-inversion-diff-59bf501198d7ff6c0c03c45e92adef14069d5ac6.tar.gz textual-inversion-diff-59bf501198d7ff6c0c03c45e92adef14069d5ac6.tar.bz2 textual-inversion-diff-59bf501198d7ff6c0c03c45e92adef14069d5ac6.zip |
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 74 |
1 file changed, 38 insertions, 36 deletions
diff --git a/train_ti.py b/train_ti.py index 3c9810f..4bac736 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -15,11 +15,11 @@ from slugify import slugify | |||
15 | 15 | ||
16 | from util import load_config, load_embeddings_from_dir | 16 | from util import load_config, load_embeddings_from_dir |
17 | from data.csv import VlpnDataModule, VlpnDataItem | 17 | from data.csv import VlpnDataModule, VlpnDataItem |
18 | from training.functional import train, generate_class_images, add_placeholder_tokens, get_models | 18 | from training.functional import train_loop, loss_step, generate_class_images, add_placeholder_tokens, get_models |
19 | from training.strategy.ti import textual_inversion_strategy | 19 | from training.strategy.ti import textual_inversion_strategy |
20 | from training.optimization import get_scheduler | 20 | from training.optimization import get_scheduler |
21 | from training.lr import LRFinder | 21 | from training.lr import LRFinder |
22 | from training.util import EMAModel, save_args | 22 | from training.util import save_args |
23 | 23 | ||
24 | logger = get_logger(__name__) | 24 | logger = get_logger(__name__) |
25 | 25 | ||
@@ -82,7 +82,7 @@ def parse_args(): | |||
82 | parser.add_argument( | 82 | parser.add_argument( |
83 | "--num_class_images", | 83 | "--num_class_images", |
84 | type=int, | 84 | type=int, |
85 | default=1, | 85 | default=0, |
86 | help="How many class images to generate." | 86 | help="How many class images to generate." |
87 | ) | 87 | ) |
88 | parser.add_argument( | 88 | parser.add_argument( |
@@ -398,7 +398,7 @@ def parse_args(): | |||
398 | ) | 398 | ) |
399 | parser.add_argument( | 399 | parser.add_argument( |
400 | "--emb_decay_factor", | 400 | "--emb_decay_factor", |
401 | default=0, | 401 | default=1, |
402 | type=float, | 402 | type=float, |
403 | help="Embedding decay factor." | 403 | help="Embedding decay factor." |
404 | ) | 404 | ) |
@@ -540,16 +540,6 @@ def main(): | |||
540 | placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens)) | 540 | placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens)) |
541 | print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}") | 541 | print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}") |
542 | 542 | ||
543 | if args.use_ema: | ||
544 | ema_embeddings = EMAModel( | ||
545 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | ||
546 | inv_gamma=args.ema_inv_gamma, | ||
547 | power=args.ema_power, | ||
548 | max_value=args.ema_max_decay, | ||
549 | ) | ||
550 | else: | ||
551 | ema_embeddings = None | ||
552 | |||
553 | if args.scale_lr: | 543 | if args.scale_lr: |
554 | args.learning_rate = ( | 544 | args.learning_rate = ( |
555 | args.learning_rate * args.gradient_accumulation_steps * | 545 | args.learning_rate * args.gradient_accumulation_steps * |
@@ -654,23 +644,13 @@ def main(): | |||
654 | warmup_epochs=args.lr_warmup_epochs, | 644 | warmup_epochs=args.lr_warmup_epochs, |
655 | ) | 645 | ) |
656 | 646 | ||
657 | if args.use_ema: | 647 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( |
658 | ema_embeddings.to(accelerator.device) | 648 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler |
659 | |||
660 | trainer = partial( | ||
661 | train, | ||
662 | accelerator=accelerator, | ||
663 | vae=vae, | ||
664 | unet=unet, | ||
665 | text_encoder=text_encoder, | ||
666 | noise_scheduler=noise_scheduler, | ||
667 | train_dataloader=train_dataloader, | ||
668 | val_dataloader=val_dataloader, | ||
669 | dtype=weight_dtype, | ||
670 | seed=args.seed, | ||
671 | ) | 649 | ) |
672 | 650 | ||
673 | strategy = textual_inversion_strategy( | 651 | vae.to(accelerator.device, dtype=weight_dtype) |
652 | |||
653 | callbacks = textual_inversion_strategy( | ||
674 | accelerator=accelerator, | 654 | accelerator=accelerator, |
675 | unet=unet, | 655 | unet=unet, |
676 | text_encoder=text_encoder, | 656 | text_encoder=text_encoder, |
@@ -679,7 +659,6 @@ def main(): | |||
679 | sample_scheduler=sample_scheduler, | 659 | sample_scheduler=sample_scheduler, |
680 | train_dataloader=train_dataloader, | 660 | train_dataloader=train_dataloader, |
681 | val_dataloader=val_dataloader, | 661 | val_dataloader=val_dataloader, |
682 | dtype=weight_dtype, | ||
683 | output_dir=output_dir, | 662 | output_dir=output_dir, |
684 | seed=args.seed, | 663 | seed=args.seed, |
685 | placeholder_tokens=args.placeholder_tokens, | 664 | placeholder_tokens=args.placeholder_tokens, |
@@ -700,31 +679,54 @@ def main(): | |||
700 | sample_image_size=args.sample_image_size, | 679 | sample_image_size=args.sample_image_size, |
701 | ) | 680 | ) |
702 | 681 | ||
682 | for model in (unet, text_encoder, vae): | ||
683 | model.requires_grad_(False) | ||
684 | model.eval() | ||
685 | |||
686 | callbacks.on_prepare() | ||
687 | |||
688 | loss_step_ = partial( | ||
689 | loss_step, | ||
690 | vae, | ||
691 | noise_scheduler, | ||
692 | unet, | ||
693 | text_encoder, | ||
694 | args.num_class_images != 0, | ||
695 | args.prior_loss_weight, | ||
696 | args.seed, | ||
697 | ) | ||
698 | |||
703 | if args.find_lr: | 699 | if args.find_lr: |
704 | lr_finder = LRFinder( | 700 | lr_finder = LRFinder( |
705 | accelerator=accelerator, | 701 | accelerator=accelerator, |
706 | optimizer=optimizer, | 702 | optimizer=optimizer, |
707 | model=text_encoder, | ||
708 | train_dataloader=train_dataloader, | 703 | train_dataloader=train_dataloader, |
709 | val_dataloader=val_dataloader, | 704 | val_dataloader=val_dataloader, |
710 | **strategy, | 705 | callbacks=callbacks, |
711 | ) | 706 | ) |
712 | lr_finder.run(num_epochs=100, end_lr=1e3) | 707 | lr_finder.run(num_epochs=100, end_lr=1e3) |
713 | 708 | ||
714 | plt.savefig(output_dir.joinpath("lr.png"), dpi=300) | 709 | plt.savefig(output_dir.joinpath("lr.png"), dpi=300) |
715 | plt.close() | 710 | plt.close() |
716 | else: | 711 | else: |
717 | trainer( | 712 | if accelerator.is_main_process: |
713 | accelerator.init_trackers("textual_inversion") | ||
714 | |||
715 | train_loop( | ||
716 | accelerator=accelerator, | ||
718 | optimizer=optimizer, | 717 | optimizer=optimizer, |
719 | lr_scheduler=lr_scheduler, | 718 | lr_scheduler=lr_scheduler, |
720 | num_train_epochs=args.num_train_epochs, | 719 | train_dataloader=train_dataloader, |
720 | val_dataloader=val_dataloader, | ||
721 | loss_step=loss_step_, | ||
721 | sample_frequency=args.sample_frequency, | 722 | sample_frequency=args.sample_frequency, |
722 | checkpoint_frequency=args.checkpoint_frequency, | 723 | checkpoint_frequency=args.checkpoint_frequency, |
723 | global_step_offset=global_step_offset, | 724 | global_step_offset=global_step_offset, |
724 | prior_loss_weight=args.prior_loss_weight, | 725 | callbacks=callbacks, |
725 | callbacks=strategy, | ||
726 | ) | 726 | ) |
727 | 727 | ||
728 | accelerator.end_training() | ||
729 | |||
728 | 730 | ||
729 | if __name__ == "__main__": | 731 | if __name__ == "__main__": |
730 | main() | 732 | main() |