diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 82 |
1 file changed, 21 insertions, 61 deletions
diff --git a/train_ti.py b/train_ti.py
index 4bac736..77dec12 100644
--- a/train_ti.py
+++ b/train_ti.py
| @@ -10,15 +10,13 @@ import torch.utils.checkpoint | |||
| 10 | from accelerate import Accelerator | 10 | from accelerate import Accelerator |
| 11 | from accelerate.logging import get_logger | 11 | from accelerate.logging import get_logger |
| 12 | from accelerate.utils import LoggerType, set_seed | 12 | from accelerate.utils import LoggerType, set_seed |
| 13 | import matplotlib.pyplot as plt | ||
| 14 | from slugify import slugify | 13 | from slugify import slugify |
| 15 | 14 | ||
| 16 | from util import load_config, load_embeddings_from_dir | 15 | from util import load_config, load_embeddings_from_dir |
| 17 | from data.csv import VlpnDataModule, VlpnDataItem | 16 | from data.csv import VlpnDataModule, VlpnDataItem |
| 18 | from training.functional import train_loop, loss_step, generate_class_images, add_placeholder_tokens, get_models | 17 | from training.functional import train, generate_class_images, add_placeholder_tokens, get_models |
| 19 | from training.strategy.ti import textual_inversion_strategy | 18 | from training.strategy.ti import textual_inversion_strategy |
| 20 | from training.optimization import get_scheduler | 19 | from training.optimization import get_scheduler |
| 21 | from training.lr import LRFinder | ||
| 22 | from training.util import save_args | 20 | from training.util import save_args |
| 23 | 21 | ||
| 24 | logger = get_logger(__name__) | 22 | logger = get_logger(__name__) |
| @@ -644,23 +642,33 @@ def main(): | |||
| 644 | warmup_epochs=args.lr_warmup_epochs, | 642 | warmup_epochs=args.lr_warmup_epochs, |
| 645 | ) | 643 | ) |
| 646 | 644 | ||
| 647 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 645 | trainer = partial( |
| 648 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | 646 | train, |
| 649 | ) | ||
| 650 | |||
| 651 | vae.to(accelerator.device, dtype=weight_dtype) | ||
| 652 | |||
| 653 | callbacks = textual_inversion_strategy( | ||
| 654 | accelerator=accelerator, | 647 | accelerator=accelerator, |
| 655 | unet=unet, | 648 | unet=unet, |
| 656 | text_encoder=text_encoder, | 649 | text_encoder=text_encoder, |
| 657 | tokenizer=tokenizer, | ||
| 658 | vae=vae, | 650 | vae=vae, |
| 659 | sample_scheduler=sample_scheduler, | 651 | noise_scheduler=noise_scheduler, |
| 660 | train_dataloader=train_dataloader, | 652 | train_dataloader=train_dataloader, |
| 661 | val_dataloader=val_dataloader, | 653 | val_dataloader=val_dataloader, |
| 662 | output_dir=output_dir, | 654 | dtype=weight_dtype, |
| 663 | seed=args.seed, | 655 | seed=args.seed, |
| 656 | callbacks_fn=textual_inversion_strategy | ||
| 657 | ) | ||
| 658 | |||
| 659 | trainer( | ||
| 660 | optimizer=optimizer, | ||
| 661 | lr_scheduler=lr_scheduler, | ||
| 662 | num_train_epochs=args.num_train_epochs, | ||
| 663 | sample_frequency=args.sample_frequency, | ||
| 664 | checkpoint_frequency=args.checkpoint_frequency, | ||
| 665 | global_step_offset=global_step_offset, | ||
| 666 | with_prior_preservation=args.num_class_images != 0, | ||
| 667 | prior_loss_weight=args.prior_loss_weight, | ||
| 668 | # -- | ||
| 669 | tokenizer=tokenizer, | ||
| 670 | sample_scheduler=sample_scheduler, | ||
| 671 | output_dir=output_dir, | ||
| 664 | placeholder_tokens=args.placeholder_tokens, | 672 | placeholder_tokens=args.placeholder_tokens, |
| 665 | placeholder_token_ids=placeholder_token_ids, | 673 | placeholder_token_ids=placeholder_token_ids, |
| 666 | learning_rate=args.learning_rate, | 674 | learning_rate=args.learning_rate, |
| @@ -679,54 +687,6 @@ def main(): | |||
| 679 | sample_image_size=args.sample_image_size, | 687 | sample_image_size=args.sample_image_size, |
| 680 | ) | 688 | ) |
| 681 | 689 | ||
| 682 | for model in (unet, text_encoder, vae): | ||
| 683 | model.requires_grad_(False) | ||
| 684 | model.eval() | ||
| 685 | |||
| 686 | callbacks.on_prepare() | ||
| 687 | |||
| 688 | loss_step_ = partial( | ||
| 689 | loss_step, | ||
| 690 | vae, | ||
| 691 | noise_scheduler, | ||
| 692 | unet, | ||
| 693 | text_encoder, | ||
| 694 | args.num_class_images != 0, | ||
| 695 | args.prior_loss_weight, | ||
| 696 | args.seed, | ||
| 697 | ) | ||
| 698 | |||
| 699 | if args.find_lr: | ||
| 700 | lr_finder = LRFinder( | ||
| 701 | accelerator=accelerator, | ||
| 702 | optimizer=optimizer, | ||
| 703 | train_dataloader=train_dataloader, | ||
| 704 | val_dataloader=val_dataloader, | ||
| 705 | callbacks=callbacks, | ||
| 706 | ) | ||
| 707 | lr_finder.run(num_epochs=100, end_lr=1e3) | ||
| 708 | |||
| 709 | plt.savefig(output_dir.joinpath("lr.png"), dpi=300) | ||
| 710 | plt.close() | ||
| 711 | else: | ||
| 712 | if accelerator.is_main_process: | ||
| 713 | accelerator.init_trackers("textual_inversion") | ||
| 714 | |||
| 715 | train_loop( | ||
| 716 | accelerator=accelerator, | ||
| 717 | optimizer=optimizer, | ||
| 718 | lr_scheduler=lr_scheduler, | ||
| 719 | train_dataloader=train_dataloader, | ||
| 720 | val_dataloader=val_dataloader, | ||
| 721 | loss_step=loss_step_, | ||
| 722 | sample_frequency=args.sample_frequency, | ||
| 723 | checkpoint_frequency=args.checkpoint_frequency, | ||
| 724 | global_step_offset=global_step_offset, | ||
| 725 | callbacks=callbacks, | ||
| 726 | ) | ||
| 727 | |||
| 728 | accelerator.end_training() | ||
| 729 | |||
| 730 | 690 | ||
| 731 | if __name__ == "__main__": | 691 | if __name__ == "__main__": |
| 732 | main() | 692 | main() |
