diff options
author | Volpeon <git@volpeon.ink> | 2023-03-31 09:34:55 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-31 09:34:55 +0200 |
commit | be3e05e47cded8487aaa787c54aa74770f9dcac8 (patch) | |
tree | d7ba88cbe7f72d6e910dd2d7f1916bb198e701fd /train_dreambooth.py | |
parent | Fix (diff) | |
download | textual-inversion-diff-be3e05e47cded8487aaa787c54aa74770f9dcac8.tar.gz textual-inversion-diff-be3e05e47cded8487aaa787c54aa74770f9dcac8.tar.bz2 textual-inversion-diff-be3e05e47cded8487aaa787c54aa74770f9dcac8.zip |
Fix
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index e3c8525..f1dca7f 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -598,7 +598,8 @@ def main(): | |||
598 | num_train_epochs = args.num_train_epochs | 598 | num_train_epochs = args.num_train_epochs |
599 | 599 | ||
600 | if num_train_epochs is None: | 600 | if num_train_epochs is None: |
601 | num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) | 601 | num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size |
602 | num_train_epochs = math.ceil(args.num_train_steps / num_images) | ||
602 | 603 | ||
603 | params_to_optimize = (unet.parameters(), ) | 604 | params_to_optimize = (unet.parameters(), ) |
604 | if args.train_text_encoder_epochs != 0: | 605 | if args.train_text_encoder_epochs != 0: |