From c11062dd765d77f8e78ff5403541fe43085ad763 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 13 Jan 2023 14:12:22 +0100
Subject: Simplified step calculations

---
 train_ti.py        | 24 +++++++++++-------------
 training/common.py | 42 ++++++++++++++++++++++--------------------
 2 files changed, 33 insertions(+), 33 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index 8c86586..3f4e739 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -750,8 +750,6 @@ def main():
         args.sample_steps
     )
 
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-
     if args.find_lr:
         lr_scheduler = None
     else:
@@ -765,9 +763,9 @@ def main():
             warmup_exp=args.lr_warmup_exp,
             annealing_exp=args.lr_annealing_exp,
             cycles=args.lr_cycles,
+            train_epochs=args.num_train_epochs,
             warmup_epochs=args.lr_warmup_epochs,
-            num_train_epochs=args.num_train_epochs,
-            num_update_steps_per_epoch=num_update_steps_per_epoch,
+            num_training_steps_per_epoch=len(train_dataloader),
             gradient_accumulation_steps=args.gradient_accumulation_steps
         )
 
@@ -826,13 +824,13 @@ def main():
             return {"ema_decay": ema_embeddings.decay}
         return {}
 
-    loop = partial(
+    loss_step_ = partial(
         loss_step,
         vae,
         noise_scheduler,
         unet,
         text_encoder,
-        args.num_class_images,
+        args.num_class_images != 0,
         args.prior_loss_weight,
         args.seed,
     )
@@ -869,12 +867,12 @@ def main():
 
     if args.find_lr:
         lr_finder = LRFinder(
-            accelerator,
-            text_encoder,
-            optimizer,
-            train_dataloader,
-            val_dataloader,
-            loop,
+            accelerator=accelerator,
+            optimizer=optimizer,
+            model=text_encoder,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            loss_step=loss_step_,
             on_train=on_train,
             on_eval=on_eval,
             on_after_optimize=on_after_optimize,
@@ -892,7 +890,7 @@ def main():
         checkpointer=checkpointer,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        loss_step=loop,
+        loss_step=loss_step_,
         sample_frequency=args.sample_frequency,
         sample_steps=args.sample_steps,
         checkpoint_frequency=args.checkpoint_frequency,
diff --git a/training/common.py b/training/common.py
index 842ac07..180396e 100644
--- a/training/common.py
+++ b/training/common.py
@@ -36,21 +36,24 @@ def get_scheduler(
     warmup_exp: int,
     annealing_exp: int,
     cycles: int,
+    train_epochs: int,
     warmup_epochs: int,
     optimizer: torch.optim.Optimizer,
-    num_train_epochs: int,
-    num_update_steps_per_epoch: int,
+    num_training_steps_per_epoch: int,
     gradient_accumulation_steps: int,
 ):
-    num_train_steps = num_train_epochs * num_update_steps_per_epoch
-    warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
+    num_training_steps_per_epoch = math.ceil(
+        num_training_steps_per_epoch / gradient_accumulation_steps
+    ) * gradient_accumulation_steps
+    num_training_steps = train_epochs * num_training_steps_per_epoch
+    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
 
     if id == "one_cycle":
         min_lr = 0.04 if min_lr is None else min_lr / lr
 
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_training_steps=num_training_steps,
             warmup=warmup_func,
             annealing=annealing_func,
             warmup_exp=warmup_exp,
@@ -58,21 +61,21 @@ def get_scheduler(
             min_lr=min_lr,
         )
     elif id == "cosine_with_restarts":
-        cycles = cycles if cycles is not None else math.ceil(
-            math.sqrt(((num_train_steps - warmup_steps) / num_update_steps_per_epoch)))
+        if cycles is None:
+            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
 
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
             num_cycles=cycles,
         )
     else:
         lr_scheduler = get_scheduler_(
             id,
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
         )
 
     return lr_scheduler
@@ -135,7 +138,7 @@ def loss_step(
     noise_scheduler: DDPMScheduler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
-    num_class_images: int,
+    with_prior: bool,
     prior_loss_weight: float,
     seed: int,
     step: int,
@@ -184,7 +187,7 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if num_class_images != 0:
+    if with_prior:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
@@ -226,11 +229,10 @@ def train_loop(
     on_after_optimize: Callable[[float], None] = noop,
     on_eval: Callable[[], _GeneratorContextManager] = nullcontext
 ):
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
-    num_train_steps = num_epochs * num_update_steps_per_epoch
-
+    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
-    num_epochs = math.ceil(num_train_steps / num_update_steps_per_epoch)
+
+    num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
 
     global_step = 0
@@ -244,14 +246,14 @@ def train_loop(
     max_acc_val = 0.0
 
     local_progress_bar = tqdm(
-        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        range(num_training_steps_per_epoch + num_val_steps_per_epoch),
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
     local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
 
     global_progress_bar = tqdm(
-        range(num_train_steps + num_val_steps),
+        range(num_training_steps + num_val_steps),
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
    )
@@ -309,7 +311,7 @@ def train_loop(
 
                     local_progress_bar.set_postfix(**logs)
 
-                    if global_step >= num_train_steps:
+                    if global_step >= num_training_steps:
                         break
 
             accelerator.wait_for_everyone()
-- 
cgit v1.2.3-70-g09d2
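
For context, a rough sketch of the arithmetic this commit switches to, with made-up numbers (none of the values below come from the patch): scheduler totals are now counted in raw dataloader steps, padded up to a whole gradient-accumulation cycle, instead of counting optimizer updates and multiplying back by gradient_accumulation_steps. Both forms hand the schedulers the same totals; the new one just avoids carrying the extra factor around.

    import math

    # Illustrative values only; not taken from the patch.
    num_batches_per_epoch = 10        # stands in for len(train_dataloader)
    gradient_accumulation_steps = 4
    train_epochs = 3
    warmup_epochs = 1

    # New scheme (this commit): count raw dataloader steps, rounded up to a
    # full accumulation cycle, and pass them to the schedulers directly.
    num_training_steps_per_epoch = math.ceil(
        num_batches_per_epoch / gradient_accumulation_steps
    ) * gradient_accumulation_steps                                     # 12
    num_training_steps = train_epochs * num_training_steps_per_epoch    # 36
    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch     # 12

    # Old scheme (removed): count optimizer updates, then scale back up by
    # gradient_accumulation_steps wherever a scheduler needed step counts.
    num_update_steps_per_epoch = math.ceil(
        num_batches_per_epoch / gradient_accumulation_steps)            # 3
    num_train_steps = train_epochs * num_update_steps_per_epoch         # 9

    # The totals the schedulers actually receive are unchanged.
    assert num_train_steps * gradient_accumulation_steps == num_training_steps
    assert warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps == num_warmup_steps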