From 89d471652644f449966a0cd944041c98dab7f66c Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 13 Jan 2023 07:25:24 +0100
Subject: Code deduplication

---
 training/common.py | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 55 insertions(+)

(limited to 'training')

diff --git a/training/common.py b/training/common.py
index 0b2ae44..90cf910 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,10 +1,65 @@
+import math
+
 import torch
 import torch.nn.functional as F

 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup

 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from training.optimization import get_one_cycle_schedule
+
+
+def get_scheduler(
+    id: str,
+    min_lr: float,
+    lr: float,
+    warmup_func: str,
+    annealing_func: str,
+    warmup_exp: int,
+    annealing_exp: int,
+    cycles: int,
+    warmup_epochs: int,
+    optimizer: torch.optim.Optimizer,
+    max_train_steps: int,
+    num_update_steps_per_epoch: int,
+    gradient_accumulation_steps: int,
+):
+    warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
+
+    if id == "one_cycle":
+        min_lr = 0.04 if min_lr is None else min_lr / lr
+
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            warmup=warmup_func,
+            annealing=annealing_func,
+            warmup_exp=warmup_exp,
+            annealing_exp=annealing_exp,
+            min_lr=min_lr,
+        )
+    elif id == "cosine_with_restarts":
+        cycles = cycles if cycles is not None else math.ceil(
+            math.sqrt(((max_train_steps - warmup_steps) / num_update_steps_per_epoch)))
+
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            num_cycles=cycles,
+        )
+    else:
+        lr_scheduler = get_scheduler_(
+            id,
+            optimizer=optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+        )
+
+    return lr_scheduler
+

 def generate_class_images(
     accelerator,
--
cgit v1.2.3-54-g00ecf
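
Editor's note: a minimal sketch of how the deduplicated helper might be called from a training script. The optimizer, step counts, and the "cos" warmup/annealing names are illustrative assumptions, not values taken from this repository; the accepted function names are defined by training.optimization.get_one_cycle_schedule.

import torch

from training.common import get_scheduler

# Hypothetical setup -- any torch optimizer works; model and values are illustrative.
model = torch.nn.Linear(768, 768)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

lr_scheduler = get_scheduler(
    "one_cycle",                    # or "cosine_with_restarts", or any diffusers scheduler id
    min_lr=1e-6,                    # scaled to a fraction of lr inside the helper
    lr=1e-4,
    warmup_func="cos",              # assumed name; see training.optimization
    annealing_func="cos",           # assumed name; see training.optimization
    warmup_exp=1,
    annealing_exp=1,
    cycles=None,                    # None lets cosine_with_restarts derive a cycle count
    warmup_epochs=10,
    optimizer=optimizer,
    max_train_steps=6000,
    num_update_steps_per_epoch=60,
    gradient_accumulation_steps=1,
)

# Step once per optimizer update, as with any torch/diffusers scheduler.
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()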