From 179a45253a5b3712f32bd127f693a6bb810a9c17 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 28 Mar 2023 16:24:22 +0200
Subject: Support num_train_steps arg again

---
 train_lora.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index 7ecddf0..a9c6e52 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -4,6 +4,7 @@ import logging
 import itertools
 from pathlib import Path
 from functools import partial
+import math
 
 import torch
 import torch.utils.checkpoint
@@ -178,13 +179,12 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100
+        default=None
     )
     parser.add_argument(
-        "--max_train_steps",
+        "--num_train_steps",
         type=int,
-        default=None,
-        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+        default=2000
     )
     parser.add_argument(
         "--gradient_accumulation_steps",
@@ -627,6 +627,11 @@ def main():
     )
     datamodule.setup()
 
+    num_train_epochs = args.num_train_epochs
+
+    if num_train_epochs is None:
+        num_train_epochs = math.ceil(len(datamodule.train_dataset) / args.num_train_steps)
+
     optimizer = create_optimizer(
         itertools.chain(
             unet.parameters(),
@@ -647,7 +652,7 @@
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
         end_lr=1e2,
-        train_epochs=args.num_train_epochs,
+        train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
     )
 
@@ -659,7 +664,7 @@
         seed=args.seed,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
-        num_train_epochs=args.num_train_epochs,
+        num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
-- 
cgit v1.2.3-54-g00ecf
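
A note on the fallback added in the third hunk: the expression math.ceil(len(datamodule.train_dataset) / args.num_train_steps) appears to divide the wrong way around. Assuming len(datamodule.train_dataset) reports the number of batches per epoch, a dataset of 400 batches with the default budget of 2000 steps yields ceil(400 / 2000) = 1 epoch, i.e. only 400 optimizer steps, rather than the 5 epochs needed to reach 2000. Below is a minimal sketch of the steps-to-epochs conversion under that assumption; the helper name steps_to_epochs is hypothetical, not part of this repository.

    import math

    def steps_to_epochs(num_train_steps: int, batches_per_epoch: int) -> int:
        # Smallest number of whole epochs that covers the requested step budget.
        # With gradient accumulation, divide batches_per_epoch by
        # gradient_accumulation_steps before calling this.
        return math.ceil(num_train_steps / batches_per_epoch)

    # A 2000-step budget over 400 batches per epoch needs 5 epochs.
    assert steps_to_epochs(2000, 400) == 5
    # A budget smaller than one epoch still trains for one full epoch.
    assert steps_to_epochs(200, 400) == 1

Under that reading, the fallback in the hunk would be math.ceil(args.num_train_steps / len(datamodule.train_dataset)) instead.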