Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  82
1 file changed, 21 insertions(+), 61 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 4bac736..77dec12 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -10,15 +10,13 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-import matplotlib.pyplot as plt
 from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.functional import train_loop, loss_step, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
-from training.lr import LRFinder
 from training.util import save_args
 
 logger = get_logger(__name__)
@@ -644,23 +642,33 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )
 
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
-
-    vae.to(accelerator.device, dtype=weight_dtype)
-
-    callbacks = textual_inversion_strategy(
+    trainer = partial(
+        train,
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
-        tokenizer=tokenizer,
         vae=vae,
-        sample_scheduler=sample_scheduler,
+        noise_scheduler=noise_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        output_dir=output_dir,
+        dtype=weight_dtype,
         seed=args.seed,
+        callbacks_fn=textual_inversion_strategy
+    )
+
+    trainer(
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        num_train_epochs=args.num_train_epochs,
+        sample_frequency=args.sample_frequency,
+        checkpoint_frequency=args.checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
+        # --
+        tokenizer=tokenizer,
+        sample_scheduler=sample_scheduler,
+        output_dir=output_dir,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
         learning_rate=args.learning_rate,
@@ -679,54 +687,6 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
-    for model in (unet, text_encoder, vae):
-        model.requires_grad_(False)
-        model.eval()
-
-    callbacks.on_prepare()
-
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        args.num_class_images != 0,
-        args.prior_loss_weight,
-        args.seed,
-    )
-
-    if args.find_lr:
-        lr_finder = LRFinder(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            callbacks=callbacks,
-        )
-        lr_finder.run(num_epochs=100, end_lr=1e3)
-
-        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
-        plt.close()
-    else:
-        if accelerator.is_main_process:
-            accelerator.init_trackers("textual_inversion")
-
-        train_loop(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
-            sample_frequency=args.sample_frequency,
-            checkpoint_frequency=args.checkpoint_frequency,
-            global_step_offset=global_step_offset,
-            callbacks=callbacks,
-        )
-
-        accelerator.end_training()
-
 
 if __name__ == "__main__":
     main()
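
The call-site logic deleted above (accelerator.prepare, model freezing, the loss_step partial, and train_loop) is evidently folded behind the new train() entry point imported from training.functional. Below is a minimal sketch of what that function plausibly looks like, inferred only from the keyword arguments passed in train_ti.py; names, structure, and the forwarding of num_train_epochs are assumptions, not this repository's actual code.

# Hypothetical sketch of train() in training/functional.py, reconstructed from
# the call site above. loss_step and train_loop are assumed to be defined
# elsewhere in the same module (both were importable from training.functional
# before this change).
from functools import partial


def train(
    accelerator, unet, text_encoder, vae, noise_scheduler,
    train_dataloader, val_dataloader, dtype, seed, callbacks_fn,
    optimizer, lr_scheduler, num_train_epochs, sample_frequency,
    checkpoint_frequency, global_step_offset,
    with_prior_preservation, prior_loss_weight,
    **strategy_kwargs,  # tokenizer, sample_scheduler, output_dir, ...
):
    # Distributed / mixed-precision setup, previously inlined in train_ti.py.
    (unet, text_encoder, optimizer,
     train_dataloader, val_dataloader, lr_scheduler) = accelerator.prepare(
        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
    )
    vae.to(accelerator.device, dtype=dtype)

    # Freeze everything; the strategy's callbacks re-enable what they train.
    for model in (unet, text_encoder, vae):
        model.requires_grad_(False)
        model.eval()

    # callbacks_fn is the strategy factory (here textual_inversion_strategy);
    # the pass-through kwargs after the "# --" marker at the call site land here.
    callbacks = callbacks_fn(
        accelerator=accelerator, unet=unet, text_encoder=text_encoder,
        vae=vae, train_dataloader=train_dataloader,
        val_dataloader=val_dataloader, seed=seed, **strategy_kwargs,
    )
    callbacks.on_prepare()

    # Per-batch loss and the loop itself, both formerly at the call site.
    loss_step_ = partial(
        loss_step, vae, noise_scheduler, unet, text_encoder,
        with_prior_preservation, prior_loss_weight, seed,
    )
    train_loop(
        accelerator=accelerator, optimizer=optimizer, lr_scheduler=lr_scheduler,
        train_dataloader=train_dataloader, val_dataloader=val_dataloader,
        loss_step=loss_step_, sample_frequency=sample_frequency,
        checkpoint_frequency=checkpoint_frequency,
        global_step_offset=global_step_offset, callbacks=callbacks,
        num_train_epochs=num_train_epochs,  # assumed forwarded to the loop
    )
    accelerator.end_training()

The partial(train, ...) pattern at the call site binds the model/data plumbing once, so the second call only supplies run-specific options; the find_lr branch and its matplotlib/LRFinder dependencies were dropped from this script outright rather than moved into train().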