diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/train_lora.py b/train_lora.py index 6f8644b..9975462 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -630,7 +630,8 @@ def main(): | |||
630 | num_train_epochs = args.num_train_epochs | 630 | num_train_epochs = args.num_train_epochs |
631 | 631 | ||
632 | if num_train_epochs is None: | 632 | if num_train_epochs is None: |
633 | num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) | 633 | num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size |
634 | num_train_epochs = math.ceil(args.num_train_steps / num_images) | ||
634 | 635 | ||
635 | optimizer = create_optimizer( | 636 | optimizer = create_optimizer( |
636 | itertools.chain( | 637 | itertools.chain( |