diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 24 |
1 file changed, 11 insertions, 13 deletions
diff --git a/train_ti.py b/train_ti.py index 8c86586..3f4e739 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -750,8 +750,6 @@ def main(): | |||
750 | args.sample_steps | 750 | args.sample_steps |
751 | ) | 751 | ) |
752 | 752 | ||
753 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) | ||
754 | |||
755 | if args.find_lr: | 753 | if args.find_lr: |
756 | lr_scheduler = None | 754 | lr_scheduler = None |
757 | else: | 755 | else: |
@@ -765,9 +763,9 @@ def main(): | |||
765 | warmup_exp=args.lr_warmup_exp, | 763 | warmup_exp=args.lr_warmup_exp, |
766 | annealing_exp=args.lr_annealing_exp, | 764 | annealing_exp=args.lr_annealing_exp, |
767 | cycles=args.lr_cycles, | 765 | cycles=args.lr_cycles, |
766 | train_epochs=args.num_train_epochs, | ||
768 | warmup_epochs=args.lr_warmup_epochs, | 767 | warmup_epochs=args.lr_warmup_epochs, |
769 | num_train_epochs=args.num_train_epochs, | 768 | num_training_steps_per_epoch=len(train_dataloader), |
770 | num_update_steps_per_epoch=num_update_steps_per_epoch, | ||
771 | gradient_accumulation_steps=args.gradient_accumulation_steps | 769 | gradient_accumulation_steps=args.gradient_accumulation_steps |
772 | ) | 770 | ) |
773 | 771 | ||
@@ -826,13 +824,13 @@ def main(): | |||
826 | return {"ema_decay": ema_embeddings.decay} | 824 | return {"ema_decay": ema_embeddings.decay} |
827 | return {} | 825 | return {} |
828 | 826 | ||
829 | loop = partial( | 827 | loss_step_ = partial( |
830 | loss_step, | 828 | loss_step, |
831 | vae, | 829 | vae, |
832 | noise_scheduler, | 830 | noise_scheduler, |
833 | unet, | 831 | unet, |
834 | text_encoder, | 832 | text_encoder, |
835 | args.num_class_images, | 833 | args.num_class_images != 0, |
836 | args.prior_loss_weight, | 834 | args.prior_loss_weight, |
837 | args.seed, | 835 | args.seed, |
838 | ) | 836 | ) |
@@ -869,12 +867,12 @@ def main(): | |||
869 | 867 | ||
870 | if args.find_lr: | 868 | if args.find_lr: |
871 | lr_finder = LRFinder( | 869 | lr_finder = LRFinder( |
872 | accelerator, | 870 | accelerator=accelerator, |
873 | text_encoder, | 871 | optimizer=optimizer, |
874 | optimizer, | 872 | model=text_encoder, |
875 | train_dataloader, | 873 | train_dataloader=train_dataloader, |
876 | val_dataloader, | 874 | val_dataloader=val_dataloader, |
877 | loop, | 875 | loss_step=loss_step_, |
878 | on_train=on_train, | 876 | on_train=on_train, |
879 | on_eval=on_eval, | 877 | on_eval=on_eval, |
880 | on_after_optimize=on_after_optimize, | 878 | on_after_optimize=on_after_optimize, |
@@ -892,7 +890,7 @@ def main(): | |||
892 | checkpointer=checkpointer, | 890 | checkpointer=checkpointer, |
893 | train_dataloader=train_dataloader, | 891 | train_dataloader=train_dataloader, |
894 | val_dataloader=val_dataloader, | 892 | val_dataloader=val_dataloader, |
895 | loss_step=loop, | 893 | loss_step=loss_step_, |
896 | sample_frequency=args.sample_frequency, | 894 | sample_frequency=args.sample_frequency, |
897 | sample_steps=args.sample_steps, | 895 | sample_steps=args.sample_steps, |
898 | checkpoint_frequency=args.checkpoint_frequency, | 896 | checkpoint_frequency=args.checkpoint_frequency, |