diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index acb8287..e3c8525 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -598,7 +598,7 @@ def main(): | |||
598 | num_train_epochs = args.num_train_epochs | 598 | num_train_epochs = args.num_train_epochs |
599 | 599 | ||
600 | if num_train_epochs is None: | 600 | if num_train_epochs is None: |
601 | num_train_epochs = math.ceil(len(datamodule.train_dataset) / args.num_train_steps) | 601 | num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) |
602 | 602 | ||
603 | params_to_optimize = (unet.parameters(), ) | 603 | params_to_optimize = (unet.parameters(), ) |
604 | if args.train_text_encoder_epochs != 0: | 604 | if args.train_text_encoder_epochs != 0: |