From 179a45253a5b3712f32bd127f693a6bb810a9c17 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 28 Mar 2023 16:24:22 +0200
Subject: Support num_train_steps arg again

---
 train_dreambooth.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 9345797..acb8287 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -4,6 +4,7 @@ import logging
 import itertools
 from pathlib import Path
 from functools import partial
+import math
 
 import torch
 import torch.utils.checkpoint
@@ -189,13 +190,12 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100
+        default=None
     )
     parser.add_argument(
-        "--max_train_steps",
+        "--num_train_steps",
         type=int,
-        default=None,
-        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+        default=2000
     )
     parser.add_argument(
         "--gradient_accumulation_steps",
@@ -595,6 +595,11 @@ def main():
     )
     datamodule.setup()
 
+    num_train_epochs = args.num_train_epochs
+
+    if num_train_epochs is None:
+        num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))
+
     params_to_optimize = (unet.parameters(), )
     if args.train_text_encoder_epochs != 0:
         params_to_optimize += (
@@ -619,7 +624,7 @@
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
         end_lr=1e2,
-        train_epochs=args.num_train_epochs,
+        train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
     )
 
@@ -631,7 +636,7 @@
         seed=args.seed,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
-        num_train_epochs=args.num_train_epochs,
+        num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
--
cgit v1.2.3-54-g00ecf
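
A minimal sketch of the fallback this patch adds to main(): when
--num_train_epochs is left at its new default of None, the step budget from
--num_train_steps is converted into an epoch count by dividing by the dataset
length and rounding up. The numbers below are hypothetical, and
train_dataset_len merely stands in for len(datamodule.train_dataset), assumed
here to be the number of optimizer steps one epoch takes.

    import math

    # Hypothetical values for illustration only.
    num_train_steps = 2000   # default of the restored --num_train_steps arg
    train_dataset_len = 40   # stands in for len(datamodule.train_dataset)

    # Round up so the run covers at least the requested number of steps.
    num_train_epochs = math.ceil(num_train_steps / train_dataset_len)

    print(num_train_epochs)  # 50 epochs x 40 steps/epoch = 2000 steps

Passing --num_train_epochs explicitly skips the derivation and uses the given
epoch count directly, as before.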