summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 2c884d2..3a25efa 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -611,7 +611,9 @@ def main():
611 num_train_epochs = args.num_train_epochs 611 num_train_epochs = args.num_train_epochs
612 sample_frequency = args.sample_frequency 612 sample_frequency = args.sample_frequency
613 if num_train_epochs is None: 613 if num_train_epochs is None:
614 num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) 614 num_train_epochs = math.ceil(
615 args.num_train_steps / len(datamodule.train_dataset)
616 ) * args.gradient_accumulation_steps
615 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) 617 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
616 618
617 params_to_optimize = (unet.parameters(), ) 619 params_to_optimize = (unet.parameters(), )