From c11062dd765d77f8e78ff5403541fe43085ad763 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 13 Jan 2023 14:12:22 +0100
Subject: Simplified step calculations

---
 train_ti.py | 24 +++++++++++-------------
 1 file changed, 11 insertions(+), 13 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 8c86586..3f4e739 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -750,8 +750,6 @@ def main():
         args.sample_steps
     )
 
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-
     if args.find_lr:
         lr_scheduler = None
     else:
@@ -765,9 +763,9 @@ def main():
             warmup_exp=args.lr_warmup_exp,
             annealing_exp=args.lr_annealing_exp,
             cycles=args.lr_cycles,
+            train_epochs=args.num_train_epochs,
             warmup_epochs=args.lr_warmup_epochs,
-            num_train_epochs=args.num_train_epochs,
-            num_update_steps_per_epoch=num_update_steps_per_epoch,
+            num_training_steps_per_epoch=len(train_dataloader),
             gradient_accumulation_steps=args.gradient_accumulation_steps
         )
 
@@ -826,13 +824,13 @@ def main():
             return {"ema_decay": ema_embeddings.decay}
         return {}
 
-    loop = partial(
+    loss_step_ = partial(
        loss_step,
        vae,
        noise_scheduler,
        unet,
        text_encoder,
-        args.num_class_images,
+        args.num_class_images != 0,
        args.prior_loss_weight,
        args.seed,
    )
@@ -869,12 +867,12 @@ def main():
 
     if args.find_lr:
         lr_finder = LRFinder(
-            accelerator,
-            text_encoder,
-            optimizer,
-            train_dataloader,
-            val_dataloader,
-            loop,
+            accelerator=accelerator,
+            optimizer=optimizer,
+            model=text_encoder,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            loss_step=loss_step_,
             on_train=on_train,
             on_eval=on_eval,
             on_after_optimize=on_after_optimize,
@@ -892,7 +890,7 @@ def main():
             checkpointer=checkpointer,
             train_dataloader=train_dataloader,
             val_dataloader=val_dataloader,
-            loss_step=loop,
+            loss_step=loss_step_,
             sample_frequency=args.sample_frequency,
             sample_steps=args.sample_steps,
             checkpoint_frequency=args.checkpoint_frequency,
--
cgit v1.2.3-54-g00ecf
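
The removed line computed the per-epoch optimizer-step count directly in train_ti.py; after this commit the scheduler helper receives the raw dataloader length (num_training_steps_per_epoch), the epoch count (train_epochs), and gradient_accumulation_steps, and is expected to do the equivalent arithmetic itself. Below is a minimal sketch of that calculation, assuming a hypothetical helper name; it is not the repository's actual get_scheduler implementation.

import math

def total_optimizer_steps(train_epochs, num_training_steps_per_epoch,
                          gradient_accumulation_steps):
    # One optimizer update happens every `gradient_accumulation_steps` batches,
    # mirroring the math.ceil(...) expression that was removed from train_ti.py.
    update_steps_per_epoch = math.ceil(
        num_training_steps_per_epoch / gradient_accumulation_steps
    )
    return train_epochs * update_steps_per_epoch

# Example: 800 batches per epoch with an accumulation factor of 4 over 100 epochs
# gives 200 optimizer updates per epoch, 20000 in total.
print(total_optimizer_steps(100, 800, 4))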