From 3e7fbb7dce321435bbbb81361debfbc499bf9231 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 13 Jan 2023 22:25:30 +0100
Subject: Reverted modularization mostly

---
 training/optimization.py | 53 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 53 insertions(+)

(limited to 'training/optimization.py')

diff --git a/training/optimization.py b/training/optimization.py
index dd84f9c..5db7794 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -5,6 +5,8 @@ from functools import partial
 import torch
 from torch.optim.lr_scheduler import LambdaLR
 
+from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
+
 
 class OneCyclePhase(NamedTuple):
     step_min: int
@@ -83,3 +85,54 @@ def get_one_cycle_schedule(
         return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min)
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_scheduler(
+    id: str,
+    optimizer: torch.optim.Optimizer,
+    num_training_steps_per_epoch: int,
+    gradient_accumulation_steps: int,
+    min_lr: float = 0.04,
+    warmup_func: str = "cos",
+    annealing_func: str = "cos",
+    warmup_exp: int = 1,
+    annealing_exp: int = 1,
+    cycles: int = 1,
+    train_epochs: int = 100,
+    warmup_epochs: int = 10,
+):
+    num_training_steps_per_epoch = math.ceil(
+        num_training_steps_per_epoch / gradient_accumulation_steps
+    ) * gradient_accumulation_steps
+    num_training_steps = train_epochs * num_training_steps_per_epoch
+    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
+
+    if id == "one_cycle":
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=num_training_steps,
+            warmup=warmup_func,
+            annealing=annealing_func,
+            warmup_exp=warmup_exp,
+            annealing_exp=annealing_exp,
+            min_lr=min_lr,
+        )
+    elif id == "cosine_with_restarts":
+        if cycles is None:
+            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            num_cycles=cycles,
+        )
+    else:
+        lr_scheduler = get_scheduler_(
+            id,
+            optimizer=optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+        )
+
+    return lr_scheduler
-- 
cgit v1.2.3-54-g00ecf
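
Below is a minimal usage sketch (not part of the commit) of how the restored get_scheduler dispatcher might be called from a training script. The model, optimizer, and step counts are hypothetical placeholders, and the import path assumes the repository layout implied by training/optimization.py.

# Usage sketch under stated assumptions; values are illustrative only.
import torch

from training.optimization import get_scheduler

model = torch.nn.Linear(16, 16)                              # hypothetical model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# "one_cycle" routes to the local get_one_cycle_schedule, "cosine_with_restarts"
# to diffusers' hard-restart schedule, and any other id falls through to
# diffusers.optimization.get_scheduler.
lr_scheduler = get_scheduler(
    "one_cycle",
    optimizer=optimizer,
    num_training_steps_per_epoch=500,
    gradient_accumulation_steps=4,
    train_epochs=100,
    warmup_epochs=10,
)

for _ in range(10):                                          # stand-in training loop
    optimizer.step()
    lr_scheduler.step()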