diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-13 14:12:22 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-13 14:12:22 +0100 |
| commit | c11062dd765d77f8e78ff5403541fe43085ad763 (patch) | |
| tree | a78884565997db906b5b709cd11b62449316b855 /train_ti.py | |
| parent | Removed PromptProcessor, modularized training loop (diff) | |
| download | textual-inversion-diff-c11062dd765d77f8e78ff5403541fe43085ad763.tar.gz textual-inversion-diff-c11062dd765d77f8e78ff5403541fe43085ad763.tar.bz2 textual-inversion-diff-c11062dd765d77f8e78ff5403541fe43085ad763.zip | |
Simplified step calculations
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 24 |
1 file changed, 11 insertions, 13 deletions
diff --git a/train_ti.py b/train_ti.py index 8c86586..3f4e739 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -750,8 +750,6 @@ def main(): | |||
| 750 | args.sample_steps | 750 | args.sample_steps |
| 751 | ) | 751 | ) |
| 752 | 752 | ||
| 753 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) | ||
| 754 | |||
| 755 | if args.find_lr: | 753 | if args.find_lr: |
| 756 | lr_scheduler = None | 754 | lr_scheduler = None |
| 757 | else: | 755 | else: |
| @@ -765,9 +763,9 @@ def main(): | |||
| 765 | warmup_exp=args.lr_warmup_exp, | 763 | warmup_exp=args.lr_warmup_exp, |
| 766 | annealing_exp=args.lr_annealing_exp, | 764 | annealing_exp=args.lr_annealing_exp, |
| 767 | cycles=args.lr_cycles, | 765 | cycles=args.lr_cycles, |
| 766 | train_epochs=args.num_train_epochs, | ||
| 768 | warmup_epochs=args.lr_warmup_epochs, | 767 | warmup_epochs=args.lr_warmup_epochs, |
| 769 | num_train_epochs=args.num_train_epochs, | 768 | num_training_steps_per_epoch=len(train_dataloader), |
| 770 | num_update_steps_per_epoch=num_update_steps_per_epoch, | ||
| 771 | gradient_accumulation_steps=args.gradient_accumulation_steps | 769 | gradient_accumulation_steps=args.gradient_accumulation_steps |
| 772 | ) | 770 | ) |
| 773 | 771 | ||
| @@ -826,13 +824,13 @@ def main(): | |||
| 826 | return {"ema_decay": ema_embeddings.decay} | 824 | return {"ema_decay": ema_embeddings.decay} |
| 827 | return {} | 825 | return {} |
| 828 | 826 | ||
| 829 | loop = partial( | 827 | loss_step_ = partial( |
| 830 | loss_step, | 828 | loss_step, |
| 831 | vae, | 829 | vae, |
| 832 | noise_scheduler, | 830 | noise_scheduler, |
| 833 | unet, | 831 | unet, |
| 834 | text_encoder, | 832 | text_encoder, |
| 835 | args.num_class_images, | 833 | args.num_class_images != 0, |
| 836 | args.prior_loss_weight, | 834 | args.prior_loss_weight, |
| 837 | args.seed, | 835 | args.seed, |
| 838 | ) | 836 | ) |
| @@ -869,12 +867,12 @@ def main(): | |||
| 869 | 867 | ||
| 870 | if args.find_lr: | 868 | if args.find_lr: |
| 871 | lr_finder = LRFinder( | 869 | lr_finder = LRFinder( |
| 872 | accelerator, | 870 | accelerator=accelerator, |
| 873 | text_encoder, | 871 | optimizer=optimizer, |
| 874 | optimizer, | 872 | model=text_encoder, |
| 875 | train_dataloader, | 873 | train_dataloader=train_dataloader, |
| 876 | val_dataloader, | 874 | val_dataloader=val_dataloader, |
| 877 | loop, | 875 | loss_step=loss_step_, |
| 878 | on_train=on_train, | 876 | on_train=on_train, |
| 879 | on_eval=on_eval, | 877 | on_eval=on_eval, |
| 880 | on_after_optimize=on_after_optimize, | 878 | on_after_optimize=on_after_optimize, |
| @@ -892,7 +890,7 @@ def main(): | |||
| 892 | checkpointer=checkpointer, | 890 | checkpointer=checkpointer, |
| 893 | train_dataloader=train_dataloader, | 891 | train_dataloader=train_dataloader, |
| 894 | val_dataloader=val_dataloader, | 892 | val_dataloader=val_dataloader, |
| 895 | loss_step=loop, | 893 | loss_step=loss_step_, |
| 896 | sample_frequency=args.sample_frequency, | 894 | sample_frequency=args.sample_frequency, |
| 897 | sample_steps=args.sample_steps, | 895 | sample_steps=args.sample_steps, |
| 898 | checkpoint_frequency=args.checkpoint_frequency, | 896 | checkpoint_frequency=args.checkpoint_frequency, |
