path: root/train_ti.py
author    Volpeon <git@volpeon.ink>  2023-01-15 09:25:30 +0100
committer Volpeon <git@volpeon.ink>  2023-01-15 09:25:30 +0100
commit    5b9a3de142e7a645573b4f4a8c1ce9c59746ab08 (patch)
tree      c551bd9a3f2f85f7aeb1e7f4bd3b2ebd0cb20450 /train_ti.py
parent    Update (diff)
download  textual-inversion-diff-5b9a3de142e7a645573b4f4a8c1ce9c59746ab08.tar.gz
          textual-inversion-diff-5b9a3de142e7a645573b4f4a8c1ce9c59746ab08.tar.bz2
          textual-inversion-diff-5b9a3de142e7a645573b4f4a8c1ce9c59746ab08.zip
Added functional trainer
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  49
1 file changed, 23 insertions(+), 26 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 78c1b5c..97e4e72 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -17,7 +17,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
 from trainer_old.base import Checkpointer
-from training.functional import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train, loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
 from training.optimization import get_scheduler
 from training.lr import LRFinder
 from training.util import EMAModel, save_args
@@ -703,17 +703,27 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )
 
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
-
-    vae.to(accelerator.device, dtype=weight_dtype)
-
     if args.use_ema:
         ema_embeddings.to(accelerator.device)
 
-    if args.gradient_checkpointing:
-        unet.train()
+    trainer = partial(
+        train,
+        accelerator=accelerator,
+        vae=vae,
+        unet=unet,
+        text_encoder=text_encoder,
+        noise_scheduler=noise_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        dtype=weight_dtype,
+        seed=args.seed,
+    )
+
+    def on_prepare():
+        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
+
+        if args.gradient_checkpointing:
+            unet.train()
 
     @contextmanager
     def on_train(epoch: int):
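The core of the hunk above is the functools.partial pattern: everything that stays fixed for the whole run (models, dataloaders, dtype, seed) is bound into trainer once, so the later call site only supplies run-specific arguments and callbacks. A minimal, self-contained sketch of that pattern, with a stand-in train function rather than the repo's training.functional.train:

from functools import partial

# Stand-in for training.functional.train with a reduced signature; the
# real function takes the models, dataloaders, and schedulers shown above.
def train(*, accelerator, unet, text_encoder, dtype, seed,
          optimizer, num_train_epochs, on_prepare=lambda: None):
    on_prepare()  # lifecycle hook runs once before the loop starts
    for epoch in range(num_train_epochs):
        pass  # per-epoch training work would go here

# Bind everything that is fixed for the whole run once...
trainer = partial(
    train,
    accelerator=object(),  # placeholders for the real objects
    unet=object(),
    text_encoder=object(),
    dtype="float16",
    seed=42,
)

# ...so the call site only passes what varies between runs.
trainer(optimizer=object(), num_train_epochs=3)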
@@ -752,16 +762,6 @@ def main():
             return {"ema_decay": ema_embeddings.decay}
         return {}
 
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        args.prior_loss_weight,
-        args.seed,
-    )
-
     checkpointer = TextualInversionCheckpointer(
         dtype=weight_dtype,
         train_dataloader=train_dataloader,
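For reference, the deleted loss_step_ used the same partial idea but bound its context positionally, leaving (step, batch) for the loop to supply at call time. A toy sketch of that old calling convention, with a stand-in body rather than the repo's actual loss_step:

from functools import partial

# Stand-in for the old loss_step: the leading context arguments are
# frozen positionally, leaving (step, batch) for the loop to supply.
def loss_step(vae, noise_scheduler, unet, text_encoder,
              prior_loss_weight, seed, step, batch):
    return {"step": step, "prior_loss_weight": prior_loss_weight}

loss_step_ = partial(loss_step, object(), object(), object(), object(), 1.0, 42)

# The training loop then called it with only the per-step arguments:
print(loss_step_(0, {"pixel_values": None}))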
@@ -803,18 +803,15 @@ def main():
         plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
-        train_loop(
-            accelerator=accelerator,
+        trainer(
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
-            model=text_encoder,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
+            num_train_epochs=args.num_train_epochs,
             sample_frequency=args.sample_frequency,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=global_step_offset,
-            num_epochs=args.num_train_epochs,
+            prior_loss_weight=args.prior_loss_weight,
+            on_prepare=on_prepare,
             on_log=on_log,
             on_train=on_train,
             on_after_optimize=on_after_optimize,
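The hooks handed over at this call site follow two shapes visible in the diff: plain callables (on_prepare, on_log, on_after_optimize) and a context manager (on_train) that wraps each epoch. A minimal sketch of how a loop might consume the context-manager hook; the loop itself is hypothetical, only the @contextmanager shape is taken from the diff:

from contextlib import contextmanager

@contextmanager
def on_train(epoch: int):
    # Enter: e.g. put models into train mode for this epoch.
    print(f"epoch {epoch}: begin")
    try:
        yield
    finally:
        # Exit: runs even if an exception interrupts the epoch.
        print(f"epoch {epoch}: end")

def train_loop(num_train_epochs, on_train):
    for epoch in range(num_train_epochs):
        with on_train(epoch):
            pass  # batch processing would go here

train_loop(2, on_train)

The context-manager form lets setup and teardown around an epoch live in one place, with the teardown guaranteed by the with statement even on error.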