diff options
author | Volpeon <git@volpeon.ink> | 2023-01-15 21:06:16 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-15 21:06:16 +0100 |
commit | 632ce00b54ffeacfc18f44f10827f167ab3ac37c (patch) | |
tree | ecf58df2b176d3c7d1583136bf453ed24de8d7f3 /train_ti.py | |
parent | Fixed Conda env (diff) | |
download | textual-inversion-diff-632ce00b54ffeacfc18f44f10827f167ab3ac37c.tar.gz textual-inversion-diff-632ce00b54ffeacfc18f44f10827f167ab3ac37c.tar.bz2 textual-inversion-diff-632ce00b54ffeacfc18f44f10827f167ab3ac37c.zip |
Restored functional trainer
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 82 |
1 file changed, 21 insertions, 61 deletions
diff --git a/train_ti.py b/train_ti.py index 4bac736..77dec12 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -10,15 +10,13 @@ import torch.utils.checkpoint | |||
10 | from accelerate import Accelerator | 10 | from accelerate import Accelerator |
11 | from accelerate.logging import get_logger | 11 | from accelerate.logging import get_logger |
12 | from accelerate.utils import LoggerType, set_seed | 12 | from accelerate.utils import LoggerType, set_seed |
13 | import matplotlib.pyplot as plt | ||
14 | from slugify import slugify | 13 | from slugify import slugify |
15 | 14 | ||
16 | from util import load_config, load_embeddings_from_dir | 15 | from util import load_config, load_embeddings_from_dir |
17 | from data.csv import VlpnDataModule, VlpnDataItem | 16 | from data.csv import VlpnDataModule, VlpnDataItem |
18 | from training.functional import train_loop, loss_step, generate_class_images, add_placeholder_tokens, get_models | 17 | from training.functional import train, generate_class_images, add_placeholder_tokens, get_models |
19 | from training.strategy.ti import textual_inversion_strategy | 18 | from training.strategy.ti import textual_inversion_strategy |
20 | from training.optimization import get_scheduler | 19 | from training.optimization import get_scheduler |
21 | from training.lr import LRFinder | ||
22 | from training.util import save_args | 20 | from training.util import save_args |
23 | 21 | ||
24 | logger = get_logger(__name__) | 22 | logger = get_logger(__name__) |
@@ -644,23 +642,33 @@ def main(): | |||
644 | warmup_epochs=args.lr_warmup_epochs, | 642 | warmup_epochs=args.lr_warmup_epochs, |
645 | ) | 643 | ) |
646 | 644 | ||
647 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 645 | trainer = partial( |
648 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | 646 | train, |
649 | ) | ||
650 | |||
651 | vae.to(accelerator.device, dtype=weight_dtype) | ||
652 | |||
653 | callbacks = textual_inversion_strategy( | ||
654 | accelerator=accelerator, | 647 | accelerator=accelerator, |
655 | unet=unet, | 648 | unet=unet, |
656 | text_encoder=text_encoder, | 649 | text_encoder=text_encoder, |
657 | tokenizer=tokenizer, | ||
658 | vae=vae, | 650 | vae=vae, |
659 | sample_scheduler=sample_scheduler, | 651 | noise_scheduler=noise_scheduler, |
660 | train_dataloader=train_dataloader, | 652 | train_dataloader=train_dataloader, |
661 | val_dataloader=val_dataloader, | 653 | val_dataloader=val_dataloader, |
662 | output_dir=output_dir, | 654 | dtype=weight_dtype, |
663 | seed=args.seed, | 655 | seed=args.seed, |
656 | callbacks_fn=textual_inversion_strategy | ||
657 | ) | ||
658 | |||
659 | trainer( | ||
660 | optimizer=optimizer, | ||
661 | lr_scheduler=lr_scheduler, | ||
662 | num_train_epochs=args.num_train_epochs, | ||
663 | sample_frequency=args.sample_frequency, | ||
664 | checkpoint_frequency=args.checkpoint_frequency, | ||
665 | global_step_offset=global_step_offset, | ||
666 | with_prior_preservation=args.num_class_images != 0, | ||
667 | prior_loss_weight=args.prior_loss_weight, | ||
668 | # -- | ||
669 | tokenizer=tokenizer, | ||
670 | sample_scheduler=sample_scheduler, | ||
671 | output_dir=output_dir, | ||
664 | placeholder_tokens=args.placeholder_tokens, | 672 | placeholder_tokens=args.placeholder_tokens, |
665 | placeholder_token_ids=placeholder_token_ids, | 673 | placeholder_token_ids=placeholder_token_ids, |
666 | learning_rate=args.learning_rate, | 674 | learning_rate=args.learning_rate, |
@@ -679,54 +687,6 @@ def main(): | |||
679 | sample_image_size=args.sample_image_size, | 687 | sample_image_size=args.sample_image_size, |
680 | ) | 688 | ) |
681 | 689 | ||
682 | for model in (unet, text_encoder, vae): | ||
683 | model.requires_grad_(False) | ||
684 | model.eval() | ||
685 | |||
686 | callbacks.on_prepare() | ||
687 | |||
688 | loss_step_ = partial( | ||
689 | loss_step, | ||
690 | vae, | ||
691 | noise_scheduler, | ||
692 | unet, | ||
693 | text_encoder, | ||
694 | args.num_class_images != 0, | ||
695 | args.prior_loss_weight, | ||
696 | args.seed, | ||
697 | ) | ||
698 | |||
699 | if args.find_lr: | ||
700 | lr_finder = LRFinder( | ||
701 | accelerator=accelerator, | ||
702 | optimizer=optimizer, | ||
703 | train_dataloader=train_dataloader, | ||
704 | val_dataloader=val_dataloader, | ||
705 | callbacks=callbacks, | ||
706 | ) | ||
707 | lr_finder.run(num_epochs=100, end_lr=1e3) | ||
708 | |||
709 | plt.savefig(output_dir.joinpath("lr.png"), dpi=300) | ||
710 | plt.close() | ||
711 | else: | ||
712 | if accelerator.is_main_process: | ||
713 | accelerator.init_trackers("textual_inversion") | ||
714 | |||
715 | train_loop( | ||
716 | accelerator=accelerator, | ||
717 | optimizer=optimizer, | ||
718 | lr_scheduler=lr_scheduler, | ||
719 | train_dataloader=train_dataloader, | ||
720 | val_dataloader=val_dataloader, | ||
721 | loss_step=loss_step_, | ||
722 | sample_frequency=args.sample_frequency, | ||
723 | checkpoint_frequency=args.checkpoint_frequency, | ||
724 | global_step_offset=global_step_offset, | ||
725 | callbacks=callbacks, | ||
726 | ) | ||
727 | |||
728 | accelerator.end_training() | ||
729 | |||
730 | 690 | ||
731 | if __name__ == "__main__": | 691 | if __name__ == "__main__": |
732 | main() | 692 | main() |