Diffstat (limited to 'train_ti.py')
 -rw-r--r--   train_ti.py   74
 1 file changed, 38 insertions(+), 36 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 3c9810f..4bac736 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -15,11 +15,11 @@ from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train_loop, loss_step, generate_class_images, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.lr import LRFinder
-from training.util import EMAModel, save_args
+from training.util import save_args
 
 logger = get_logger(__name__)
 
@@ -82,7 +82,7 @@ def parse_args():
     parser.add_argument(
         "--num_class_images",
         type=int,
-        default=1,
+        default=0,
         help="How many class images to generate."
     )
     parser.add_argument(
@@ -398,7 +398,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay_factor",
-        default=0,
+        default=1,
         type=float,
         help="Embedding decay factor."
     )
@@ -540,16 +540,6 @@ def main():
     placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
     print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
 
-    if args.use_ema:
-        ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            inv_gamma=args.ema_inv_gamma,
-            power=args.ema_power,
-            max_value=args.ema_max_decay,
-        )
-    else:
-        ema_embeddings = None
-
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
@@ -654,23 +644,13 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )
 
-    if args.use_ema:
-        ema_embeddings.to(accelerator.device)
-
-    trainer = partial(
-        train,
-        accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        text_encoder=text_encoder,
-        noise_scheduler=noise_scheduler,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        dtype=weight_dtype,
-        seed=args.seed,
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
 
-    strategy = textual_inversion_strategy(
+    vae.to(accelerator.device, dtype=weight_dtype)
+
+    callbacks = textual_inversion_strategy(
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
@@ -679,7 +659,6 @@ def main():
         sample_scheduler=sample_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        dtype=weight_dtype,
         output_dir=output_dir,
         seed=args.seed,
         placeholder_tokens=args.placeholder_tokens,
@@ -700,31 +679,54 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
+    for model in (unet, text_encoder, vae):
+        model.requires_grad_(False)
+        model.eval()
+
+    callbacks.on_prepare()
+
+    loss_step_ = partial(
+        loss_step,
+        vae,
+        noise_scheduler,
+        unet,
+        text_encoder,
+        args.num_class_images != 0,
+        args.prior_loss_weight,
+        args.seed,
+    )
+
     if args.find_lr:
         lr_finder = LRFinder(
             accelerator=accelerator,
             optimizer=optimizer,
-            model=text_encoder,
             train_dataloader=train_dataloader,
             val_dataloader=val_dataloader,
-            **strategy,
+            callbacks=callbacks,
         )
         lr_finder.run(num_epochs=100, end_lr=1e3)
 
         plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
-        trainer(
+        if accelerator.is_main_process:
+            accelerator.init_trackers("textual_inversion")
+
+        train_loop(
+            accelerator=accelerator,
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
-            num_train_epochs=args.num_train_epochs,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            loss_step=loss_step_,
             sample_frequency=args.sample_frequency,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=global_step_offset,
-            prior_loss_weight=args.prior_loss_weight,
-            callbacks=strategy,
+            callbacks=callbacks,
         )
 
+    accelerator.end_training()
+
 
 if __name__ == "__main__":
     main()
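Note on the pattern in the last hunk: the commit binds the fixed arguments of loss_step once with functools.partial and hands the resulting callable to train_loop, which then only supplies the per-step arguments. The snippet below is a minimal, self-contained sketch of that binding pattern; the function name, parameters, and values are illustrative placeholders, not the actual training.functional API.

from functools import partial

def loss_step(model_name, prior_loss_weight, step, batch):
    # Stand-in for the real loss computation: the first two arguments are
    # bound once below; the last two are supplied by the loop on every call.
    return f"step {step}: {model_name} loss on {batch!r} (prior weight {prior_loss_weight})"

# Mirrors `loss_step_ = partial(loss_step, vae, noise_scheduler, ...)` in the diff:
# the model handles and flags are fixed up front.
loss_step_ = partial(loss_step, "unet", 1.0)

for step, batch in enumerate(["batch-a", "batch-b"]):
    print(loss_step_(step, batch))  # the generic loop passes only (step, batch)

Binding the models and flags up front is what lets train_loop stay generic: it can drive any loss callable that accepts the remaining per-step arguments without knowing about the VAE, UNet, or prior-preservation settings.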
