summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-31 09:34:55 +0200
committerVolpeon <git@volpeon.ink>2023-03-31 09:34:55 +0200
commitbe3e05e47cded8487aaa787c54aa74770f9dcac8 (patch)
treed7ba88cbe7f72d6e910dd2d7f1916bb198e701fd /train_dreambooth.py
parentFix (diff)
downloadtextual-inversion-diff-be3e05e47cded8487aaa787c54aa74770f9dcac8.tar.gz
textual-inversion-diff-be3e05e47cded8487aaa787c54aa74770f9dcac8.tar.bz2
textual-inversion-diff-be3e05e47cded8487aaa787c54aa74770f9dcac8.zip
Fix
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index e3c8525..f1dca7f 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -598,7 +598,8 @@ def main():
598 num_train_epochs = args.num_train_epochs 598 num_train_epochs = args.num_train_epochs
599 599
600 if num_train_epochs is None: 600 if num_train_epochs is None:
601 num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) 601 num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size
602 num_train_epochs = math.ceil(args.num_train_steps / num_images)
602 603
603 params_to_optimize = (unet.parameters(), ) 604 params_to_optimize = (unet.parameters(), )
604 if args.train_text_encoder_epochs != 0: 605 if args.train_text_encoder_epochs != 0: