diff options
author | Volpeon <git@volpeon.ink> | 2023-03-31 22:17:26 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-31 22:17:26 +0200 |
commit | bf24028a869c849a29d23b05db0284a158d201f0 (patch) | |
tree | dc90ed7d2c6da10ded6f2c98d0f504f9372b6683 /train_lora.py | |
parent | Update (diff) | |
download | textual-inversion-diff-bf24028a869c849a29d23b05db0284a158d201f0.tar.gz textual-inversion-diff-bf24028a869c849a29d23b05db0284a158d201f0.tar.bz2 textual-inversion-diff-bf24028a869c849a29d23b05db0284a158d201f0.zip |
Update
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/train_lora.py b/train_lora.py index 59beb09..f74a438 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -643,7 +643,9 @@ def main(): | |||
643 | num_train_epochs = args.num_train_epochs | 643 | num_train_epochs = args.num_train_epochs |
644 | sample_frequency = args.sample_frequency | 644 | sample_frequency = args.sample_frequency |
645 | if num_train_epochs is None: | 645 | if num_train_epochs is None: |
646 | num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) | 646 | num_train_epochs = math.ceil( |
647 | args.num_train_steps / len(datamodule.train_dataset) | ||
648 | ) * args.gradient_accumulation_steps | ||
647 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 649 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) |
648 | 650 | ||
649 | optimizer = create_optimizer( | 651 | optimizer = create_optimizer( |