diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index e3c8525..f1dca7f 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -598,7 +598,8 @@ def main(): | |||
598 | num_train_epochs = args.num_train_epochs | 598 | num_train_epochs = args.num_train_epochs |
599 | 599 | ||
600 | if num_train_epochs is None: | 600 | if num_train_epochs is None: |
601 | num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) | 601 | num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size |
602 | num_train_epochs = math.ceil(args.num_train_steps / num_images) | ||
602 | 603 | ||
603 | params_to_optimize = (unet.parameters(), ) | 604 | params_to_optimize = (unet.parameters(), ) |
604 | if args.train_text_encoder_epochs != 0: | 605 | if args.train_text_encoder_epochs != 0: |