From c11062dd765d77f8e78ff5403541fe43085ad763 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 13 Jan 2023 14:12:22 +0100
Subject: Simplified step calculations

---
 training/common.py | 42 ++++++++++++++++++++++--------------------
 1 file changed, 22 insertions(+), 20 deletions(-)

(limited to 'training')

diff --git a/training/common.py b/training/common.py
index 842ac07..180396e 100644
--- a/training/common.py
+++ b/training/common.py
@@ -36,21 +36,24 @@ def get_scheduler(
     warmup_exp: int,
     annealing_exp: int,
     cycles: int,
+    train_epochs: int,
     warmup_epochs: int,
     optimizer: torch.optim.Optimizer,
-    num_train_epochs: int,
-    num_update_steps_per_epoch: int,
+    num_training_steps_per_epoch: int,
     gradient_accumulation_steps: int,
 ):
-    num_train_steps = num_train_epochs * num_update_steps_per_epoch
-    warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
+    num_training_steps_per_epoch = math.ceil(
+        num_training_steps_per_epoch / gradient_accumulation_steps
+    ) * gradient_accumulation_steps
+    num_training_steps = train_epochs * num_training_steps_per_epoch
+    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
 
     if id == "one_cycle":
         min_lr = 0.04 if min_lr is None else min_lr / lr
 
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_training_steps=num_training_steps,
             warmup=warmup_func,
             annealing=annealing_func,
             warmup_exp=warmup_exp,
@@ -58,21 +61,21 @@ def get_scheduler(
             min_lr=min_lr,
         )
     elif id == "cosine_with_restarts":
-        cycles = cycles if cycles is not None else math.ceil(
-            math.sqrt(((num_train_steps - warmup_steps) / num_update_steps_per_epoch)))
+        if cycles is None:
+            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
 
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
             num_cycles=cycles,
         )
     else:
         lr_scheduler = get_scheduler_(
             id,
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=num_train_steps * gradient_accumulation_steps,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
         )
 
     return lr_scheduler
@@ -135,7 +138,7 @@ def loss_step(
     noise_scheduler: DDPMScheduler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
-    num_class_images: int,
+    with_prior: bool,
     prior_loss_weight: float,
     seed: int,
     step: int,
@@ -184,7 +187,7 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if num_class_images != 0:
+    if with_prior:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
@@ -226,11 +229,10 @@ def train_loop(
     on_after_optimize: Callable[[float], None] = noop,
     on_eval: Callable[[], _GeneratorContextManager] = nullcontext
 ):
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
-    num_train_steps = num_epochs * num_update_steps_per_epoch
-
+    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
-    num_epochs = math.ceil(num_train_steps / num_update_steps_per_epoch)
+
+    num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
 
     global_step = 0
@@ -244,14 +246,14 @@ def train_loop(
     max_acc_val = 0.0
 
     local_progress_bar = tqdm(
-        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        range(num_training_steps_per_epoch + num_val_steps_per_epoch),
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
    )
     local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
 
     global_progress_bar = tqdm(
-        range(num_train_steps + num_val_steps),
+        range(num_training_steps + num_val_steps),
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
@@ -309,7 +311,7 @@ def train_loop(
 
             local_progress_bar.set_postfix(**logs)
 
-            if global_step >= num_train_steps:
+            if global_step >= num_training_steps:
                 break
 
     accelerator.wait_for_everyone()
-- 
cgit v1.2.3-54-g00ecf
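
For reference, a minimal sketch of the step arithmetic this patch converges on in get_scheduler(): the per-epoch step count is rounded up to a whole multiple of the gradient-accumulation window, and the training, warmup, and default restart-cycle counts are all derived from that rounded figure. The helper name and the numbers are hypothetical illustrations, not part of training/common.py.

    import math

    def scheduler_step_counts(
        steps_per_epoch: int,
        gradient_accumulation_steps: int,
        train_epochs: int,
        warmup_epochs: int,
    ):
        # Hypothetical helper for illustration; mirrors the rounding done in
        # get_scheduler() above: pad the per-epoch count up to a whole number
        # of accumulation windows.
        steps_per_epoch = math.ceil(
            steps_per_epoch / gradient_accumulation_steps
        ) * gradient_accumulation_steps

        num_training_steps = train_epochs * steps_per_epoch
        num_warmup_steps = warmup_epochs * steps_per_epoch

        # Default cycle count used by the cosine_with_restarts branch when no
        # explicit value is supplied.
        cycles = math.ceil(
            math.sqrt((num_training_steps - num_warmup_steps) / steps_per_epoch)
        )
        return num_training_steps, num_warmup_steps, cycles

    # Illustrative values: 403 batches/epoch, accumulate over 4, 100 epochs, 5 warmup epochs.
    print(scheduler_step_counts(403, 4, 100, 5))  # -> (40400, 2020, 10)

Because both warmup and total steps now come from the same rounded per-epoch value, the previous inconsistency of multiplying only some of them by gradient_accumulation_steps goes away.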