summary refs log tree commit diff stats
path: root/train_ti.py
diff options
context:
space:
mode:
author	Volpeon <git@volpeon.ink>	2023-01-13 14:12:22 +0100
committer	Volpeon <git@volpeon.ink>	2023-01-13 14:12:22 +0100
commitc11062dd765d77f8e78ff5403541fe43085ad763 (patch)
treea78884565997db906b5b709cd11b62449316b855 /train_ti.py
parentRemoved PromptProcessor, modularized training loop (diff)
downloadtextual-inversion-diff-c11062dd765d77f8e78ff5403541fe43085ad763.tar.gz
textual-inversion-diff-c11062dd765d77f8e78ff5403541fe43085ad763.tar.bz2
textual-inversion-diff-c11062dd765d77f8e78ff5403541fe43085ad763.zip
Simplified step calculations
Diffstat (limited to 'train_ti.py')
-rw-r--r--	train_ti.py	24
1 file changed, 11 insertions(+), 13 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 8c86586..3f4e739 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -750,8 +750,6 @@ def main():
750 args.sample_steps 750 args.sample_steps
751 ) 751 )
752 752
753 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
754
755 if args.find_lr: 753 if args.find_lr:
756 lr_scheduler = None 754 lr_scheduler = None
757 else: 755 else:
@@ -765,9 +763,9 @@ def main():
765 warmup_exp=args.lr_warmup_exp, 763 warmup_exp=args.lr_warmup_exp,
766 annealing_exp=args.lr_annealing_exp, 764 annealing_exp=args.lr_annealing_exp,
767 cycles=args.lr_cycles, 765 cycles=args.lr_cycles,
766 train_epochs=args.num_train_epochs,
768 warmup_epochs=args.lr_warmup_epochs, 767 warmup_epochs=args.lr_warmup_epochs,
769 num_train_epochs=args.num_train_epochs, 768 num_training_steps_per_epoch=len(train_dataloader),
770 num_update_steps_per_epoch=num_update_steps_per_epoch,
771 gradient_accumulation_steps=args.gradient_accumulation_steps 769 gradient_accumulation_steps=args.gradient_accumulation_steps
772 ) 770 )
773 771
@@ -826,13 +824,13 @@ def main():
826 return {"ema_decay": ema_embeddings.decay} 824 return {"ema_decay": ema_embeddings.decay}
827 return {} 825 return {}
828 826
829 loop = partial( 827 loss_step_ = partial(
830 loss_step, 828 loss_step,
831 vae, 829 vae,
832 noise_scheduler, 830 noise_scheduler,
833 unet, 831 unet,
834 text_encoder, 832 text_encoder,
835 args.num_class_images, 833 args.num_class_images != 0,
836 args.prior_loss_weight, 834 args.prior_loss_weight,
837 args.seed, 835 args.seed,
838 ) 836 )
@@ -869,12 +867,12 @@ def main():
869 867
870 if args.find_lr: 868 if args.find_lr:
871 lr_finder = LRFinder( 869 lr_finder = LRFinder(
872 accelerator, 870 accelerator=accelerator,
873 text_encoder, 871 optimizer=optimizer,
874 optimizer, 872 model=text_encoder,
875 train_dataloader, 873 train_dataloader=train_dataloader,
876 val_dataloader, 874 val_dataloader=val_dataloader,
877 loop, 875 loss_step=loss_step_,
878 on_train=on_train, 876 on_train=on_train,
879 on_eval=on_eval, 877 on_eval=on_eval,
880 on_after_optimize=on_after_optimize, 878 on_after_optimize=on_after_optimize,
@@ -892,7 +890,7 @@ def main():
892 checkpointer=checkpointer, 890 checkpointer=checkpointer,
893 train_dataloader=train_dataloader, 891 train_dataloader=train_dataloader,
894 val_dataloader=val_dataloader, 892 val_dataloader=val_dataloader,
895 loss_step=loop, 893 loss_step=loss_step_,
896 sample_frequency=args.sample_frequency, 894 sample_frequency=args.sample_frequency,
897 sample_steps=args.sample_steps, 895 sample_steps=args.sample_steps,
898 checkpoint_frequency=args.checkpoint_frequency, 896 checkpoint_frequency=args.checkpoint_frequency,