Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py | 74
 1 file changed, 38 insertions(+), 36 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 3c9810f..4bac736 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -15,11 +15,11 @@ from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train_loop, loss_step, generate_class_images, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.lr import LRFinder
-from training.util import EMAModel, save_args
+from training.util import save_args
 
 logger = get_logger(__name__)
 
@@ -82,7 +82,7 @@ def parse_args():
     parser.add_argument(
         "--num_class_images",
         type=int,
-        default=1,
+        default=0,
         help="How many class images to generate."
     )
     parser.add_argument(
@@ -398,7 +398,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay_factor",
-        default=0,
+        default=1,
         type=float,
         help="Embedding decay factor."
     )
@@ -540,16 +540,6 @@ def main():
     placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
     print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
 
-    if args.use_ema:
-        ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            inv_gamma=args.ema_inv_gamma,
-            power=args.ema_power,
-            max_value=args.ema_max_decay,
-        )
-    else:
-        ema_embeddings = None
-
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
@@ -654,23 +644,13 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )
 
-    if args.use_ema:
-        ema_embeddings.to(accelerator.device)
-
-    trainer = partial(
-        train,
-        accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        text_encoder=text_encoder,
-        noise_scheduler=noise_scheduler,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        dtype=weight_dtype,
-        seed=args.seed,
-    )
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
 
-    strategy = textual_inversion_strategy(
+    vae.to(accelerator.device, dtype=weight_dtype)
+
+    callbacks = textual_inversion_strategy(
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
@@ -679,7 +659,6 @@ def main():
         sample_scheduler=sample_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        dtype=weight_dtype,
         output_dir=output_dir,
         seed=args.seed,
         placeholder_tokens=args.placeholder_tokens,
@@ -700,31 +679,54 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
+    for model in (unet, text_encoder, vae):
+        model.requires_grad_(False)
+        model.eval()
+
+    callbacks.on_prepare()
+
+    loss_step_ = partial(
+        loss_step,
+        vae,
+        noise_scheduler,
+        unet,
+        text_encoder,
+        args.num_class_images != 0,
+        args.prior_loss_weight,
+        args.seed,
+    )
+
     if args.find_lr:
         lr_finder = LRFinder(
             accelerator=accelerator,
             optimizer=optimizer,
-            model=text_encoder,
             train_dataloader=train_dataloader,
             val_dataloader=val_dataloader,
-            **strategy,
+            callbacks=callbacks,
         )
         lr_finder.run(num_epochs=100, end_lr=1e3)
 
         plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
-        trainer(
+        if accelerator.is_main_process:
+            accelerator.init_trackers("textual_inversion")
+
+        train_loop(
+            accelerator=accelerator,
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
-            num_train_epochs=args.num_train_epochs,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            loss_step=loss_step_,
             sample_frequency=args.sample_frequency,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=global_step_offset,
-            prior_loss_weight=args.prior_loss_weight,
-            callbacks=strategy,
+            callbacks=callbacks,
         )
 
+    accelerator.end_training()
+
 
 if __name__ == "__main__":
     main()
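
Note: the commit above replaces the monolithic train() entry point with a train_loop() driven by a partially applied loss_step() plus a callbacks bundle from textual_inversion_strategy(). Below is a minimal, self-contained sketch of that wiring pattern, assuming toy stand-ins; the names Callbacks, toy_loss_step, and the loop body here are illustrative, not the actual definitions in training/functional.py or training/strategy/ti.py.

# Sketch of the loss_step / train_loop / callbacks pattern used by this commit.
# All names here are illustrative stand-ins, not the repository's real API.
from dataclasses import dataclass
from functools import partial
from typing import Callable, Iterable


@dataclass
class Callbacks:
    # Hypothetical callback bundle; the real one is returned by
    # training.strategy.ti.textual_inversion_strategy.
    on_prepare: Callable[[], None] = lambda: None
    on_checkpoint: Callable[[int], None] = lambda step: None


def toy_loss_step(scale: float, batch: float) -> float:
    # Stand-in for training.functional.loss_step: the leading arguments
    # (vae, noise_scheduler, unet, text_encoder, ...) are bound once with
    # functools.partial, leaving only the per-step batch to the loop.
    return scale * batch * batch


def train_loop(
    loss_step: Callable[[float], float],
    train_dataloader: Iterable[float],
    callbacks: Callbacks,
    checkpoint_frequency: int = 2,
) -> None:
    # Stand-in for training.functional.train_loop: the loop only knows how
    # to call the bound loss_step and fire the strategy's callbacks, so the
    # same loop can serve textual inversion or any other strategy.
    callbacks.on_prepare()
    for step, batch in enumerate(train_dataloader):
        loss = loss_step(batch)
        print(f"step={step} loss={loss:.4f}")
        if (step + 1) % checkpoint_frequency == 0:
            callbacks.on_checkpoint(step)


loss_step_ = partial(toy_loss_step, 0.5)  # bind everything except the batch
train_loop(loss_step_, [1.0, 2.0, 3.0, 4.0], Callbacks())

This is also why the EMAModel import and the use_ema branches disappear from train_ti.py: any EMA handling now belongs behind the strategy's callbacks rather than in the script itself.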