From bf24028a869c849a29d23b05db0284a158d201f0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 31 Mar 2023 22:17:26 +0200 Subject: Update --- train_dreambooth.py | 4 +++- train_lora.py | 4 +++- train_ti.py | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/train_dreambooth.py b/train_dreambooth.py index 2c884d2..3a25efa 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -611,7 +611,9 @@ def main(): num_train_epochs = args.num_train_epochs sample_frequency = args.sample_frequency if num_train_epochs is None: - num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) + num_train_epochs = math.ceil( + args.num_train_steps / len(datamodule.train_dataset) + ) * args.gradient_accumulation_steps sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) params_to_optimize = (unet.parameters(), ) diff --git a/train_lora.py b/train_lora.py index 59beb09..f74a438 100644 --- a/train_lora.py +++ b/train_lora.py @@ -643,7 +643,9 @@ def main(): num_train_epochs = args.num_train_epochs sample_frequency = args.sample_frequency if num_train_epochs is None: - num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) + num_train_epochs = math.ceil( + args.num_train_steps / len(datamodule.train_dataset) + ) * args.gradient_accumulation_steps sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) optimizer = create_optimizer( diff --git a/train_ti.py b/train_ti.py index 83043ad..dd015f9 100644 --- a/train_ti.py +++ b/train_ti.py @@ -774,7 +774,9 @@ def main(): num_train_epochs = args.num_train_epochs sample_frequency = args.sample_frequency if num_train_epochs is None: - num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) + num_train_epochs = math.ceil( + args.num_train_steps / len(datamodule.train_dataset) + ) * args.gradient_accumulation_steps sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) optimizer = create_optimizer( -- cgit v1.2.3-70-g09d2