Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py | 49 +++++++++++++++++++++++--------------------------
 1 file changed, 23 insertions(+), 26 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 78c1b5c..97e4e72 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -17,7 +17,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
 from trainer_old.base import Checkpointer
-from training.functional import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train, loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
 from training.optimization import get_scheduler
 from training.lr import LRFinder
 from training.util import EMAModel, save_args
@@ -703,17 +703,27 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )
 
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
-
-    vae.to(accelerator.device, dtype=weight_dtype)
-
     if args.use_ema:
         ema_embeddings.to(accelerator.device)
 
-    if args.gradient_checkpointing:
-        unet.train()
+    trainer = partial(
+        train,
+        accelerator=accelerator,
+        vae=vae,
+        unet=unet,
+        text_encoder=text_encoder,
+        noise_scheduler=noise_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        dtype=weight_dtype,
+        seed=args.seed,
+    )
+
+    def on_prepare():
+        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
+
+        if args.gradient_checkpointing:
+            unet.train()
 
     @contextmanager
     def on_train(epoch: int):
@@ -752,16 +762,6 @@ def main():
             return {"ema_decay": ema_embeddings.decay}
         return {}
 
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        args.prior_loss_weight,
-        args.seed,
-    )
-
     checkpointer = TextualInversionCheckpointer(
         dtype=weight_dtype,
         train_dataloader=train_dataloader,
@@ -803,18 +803,15 @@ def main():
         plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
-        train_loop(
-            accelerator=accelerator,
+        trainer(
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
-            model=text_encoder,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
+            num_train_epochs=args.num_train_epochs,
             sample_frequency=args.sample_frequency,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=global_step_offset,
-            num_epochs=args.num_train_epochs,
+            prior_loss_weight=args.prior_loss_weight,
+            on_prepare=on_prepare,
             on_log=on_log,
             on_train=on_train,
             on_after_optimize=on_after_optimize,
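The commit folds the separate loss_step_ partial and the train_loop(...) call into a single train entry point: run-wide state (accelerator, models, dataloaders, dtype, seed) is bound once with functools.partial, and only per-run options and hooks (on_prepare, on_log, ...) are passed at the call site. A minimal sketch of that pattern follows; train below is a hypothetical, simplified stand-in for training.functional.train, whose full signature is not visible in this diff.

from functools import partial

# Hypothetical stand-in for training.functional.train: only the keyword
# names visible in the diff are taken from the real function.
def train(*, accelerator, unet, text_encoder, optimizer,
          num_train_epochs, on_prepare=lambda: None):
    # The unified entry point would call accelerator.prepare(...) itself
    # (the block removed above), then run the hook before the epoch loop.
    on_prepare()
    for epoch in range(num_train_epochs):
        pass  # loss_step / train_loop logic now lives inside train()

# Run-wide state is bound once ...
trainer = partial(train, accelerator="accelerator", unet="unet",
                  text_encoder="text_encoder")

# ... so the call site only supplies what varies between invocations.
trainer(optimizer="optimizer", num_train_epochs=2,
        on_prepare=lambda: print("unfreeze temp_token_embedding"))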