import math
from typing import NamedTuple, Literal, Callable
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR

from diffusers.optimization import (
    get_scheduler as get_scheduler_,
    get_cosine_with_hard_restarts_schedule_with_warmup,
)
from transformers.optimization import get_adafactor_schedule


class OneCyclePhase(NamedTuple):
    # One segment of the one-cycle curve: interpolate the lr multiplier between
    # `min` and `max` over [step_min, step_max) using `func`.
    step_min: int
    step_max: int
    min: float
    max: float
    func: Callable[[float], float]


def warmup_linear(progress: float):
    # Linear ramp from 0 to 1.
    return progress


def warmup_cos(exp: int, progress: float):
    # Cosine ramp from 0 to 1, optionally reshaped by `exp`.
    lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
    lr = lr ** (exp - (exp - 1) * progress)
    return lr


def anneal_linear(progress: float):
    # Linear decay from 1 to 0.
    return 1 - progress


def anneal_half_cos(exp: int, progress: float):
    # Half-cosine decay from 1 to 0, optionally reshaped by `exp`.
    lr = 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
    lr = lr ** (exp - (exp - 1) * progress)
    return lr


def anneal_cos(exp: int, progress: float):
    # Cosine decay from 1 to 0, optionally reshaped by `exp`.
    lr = 0.5 * (1.0 + math.cos(math.pi * progress))
    lr = lr ** (exp - (exp - 1) * progress)
    return lr


def get_one_cycle_schedule(
    optimizer: torch.optim.Optimizer,
    num_training_steps: int,
    warmup: Literal["cos", "linear"] = "cos",
    annealing: Literal["cos", "half_cos", "linear"] = "cos",
    warmup_exp: int = 1,
    annealing_exp: int = 1,
    min_lr: float = 0.04,
    mid_point: float = 0.3,
    last_epoch: int = -1,
):
    """One-cycle schedule as a multiplier on the optimizer's base learning rate.

    The multiplier ramps from `min_lr` to 1 over the first `mid_point` fraction of
    training (capped at 0.5), then anneals back toward 0. With linear annealing the
    decay is split into two phases: 1 -> `min_lr`, then `min_lr` -> 0.
    """
    if warmup == "linear":
        warmup_func = warmup_linear
    else:
        warmup_func = partial(warmup_cos, warmup_exp)

    if annealing == "linear":
        anneal_func = anneal_linear
    elif annealing == "half_cos":
        anneal_func = partial(anneal_half_cos, annealing_exp)
    else:
        anneal_func = partial(anneal_cos, annealing_exp)

    # Peak of the cycle; mid_point is capped so warmup never exceeds half of training.
    thresh_up = int(num_training_steps * min(mid_point, 0.5))

    if annealing == "linear":
        thresh_down = thresh_up * 2
        phases = [
            OneCyclePhase(0, thresh_up, min_lr, 1, warmup_func),
            OneCyclePhase(thresh_up, thresh_down, min_lr, 1, anneal_func),
            OneCyclePhase(thresh_down, num_training_steps, 0, min_lr, anneal_func),
        ]
    else:
        phases = [
            OneCyclePhase(0, thresh_up, min_lr, 1, warmup_func),
            OneCyclePhase(thresh_up, num_training_steps, 0, 1, anneal_func),
        ]

    def lr_lambda(current_step: int):
        # Pick the phase the current step falls into and interpolate within it.
        phase = [p for p in phases if current_step >= p.step_min][-1]
        return phase.min + phase.func(
            (current_step - phase.step_min) / (phase.step_max - phase.step_min)
        ) * (phase.max - phase.min)

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_exponential_growing_schedule(
    optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1
):
    """Grow each param group's lr exponentially from its initial lr to `end_lr`."""

    def lr_lambda(base_lr: float, current_step: int):
        return (end_lr / base_lr) ** (current_step / num_training_steps)

    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]

    return LambdaLR(optimizer, lr_lambdas, last_epoch)
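
# Illustrative sketch (not part of the original module): one way to preview the
# one-cycle curve before training. The helper name `_preview_one_cycle_lrs`, the
# dummy parameter, and the default values below are assumptions chosen for the
# example only; nothing else in this file uses them.
def _preview_one_cycle_lrs(num_training_steps: int = 100, base_lr: float = 1e-3):
    """Return the learning rates the one-cycle schedule would produce."""
    dummy = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([dummy], lr=base_lr)
    scheduler = get_one_cycle_schedule(optimizer, num_training_steps)

    lrs = []
    for _ in range(num_training_steps):
        lrs.append(scheduler.get_last_lr()[0])
        optimizer.step()
        scheduler.step()
    return lrs
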
def get_scheduler(
    id: str,
    optimizer: torch.optim.Optimizer,
    num_training_steps_per_epoch: int,
    gradient_accumulation_steps: int = 1,
    min_lr: float = 0.04,
    mid_point: float = 0.3,
    warmup_func: Literal["cos", "linear"] = "cos",
    annealing_func: Literal["cos", "half_cos", "linear"] = "cos",
    warmup_exp: int = 1,
    annealing_exp: int = 1,
    end_lr: float = 1e3,
    cycles: int = 1,
    train_epochs: int = 100,
    warmup_epochs: int = 10,
):
    """Build a learning-rate scheduler by `id`.

    Handles the custom "one_cycle" and "exponential_growth" schedules,
    "cosine_with_restarts", and "adafactor"; any other id falls through to
    diffusers' `get_scheduler`. Step counts are derived from epochs and the
    per-epoch step count, adjusted for gradient accumulation.
    """
    num_training_steps_per_epoch = math.ceil(
        num_training_steps_per_epoch / gradient_accumulation_steps
    )  # * gradient_accumulation_steps
    num_training_steps = train_epochs * num_training_steps_per_epoch
    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch

    if id == "one_cycle":
        lr_scheduler = get_one_cycle_schedule(
            optimizer=optimizer,
            num_training_steps=num_training_steps,
            warmup=warmup_func,
            annealing=annealing_func,
            warmup_exp=warmup_exp,
            annealing_exp=annealing_exp,
            min_lr=min_lr,
            mid_point=mid_point,
        )
    elif id == "exponential_growth":
        if cycles is None:
            cycles = math.ceil(
                math.sqrt(
                    (num_training_steps - num_warmup_steps)
                    / num_training_steps_per_epoch
                )
            )

        lr_scheduler = get_exponential_growing_schedule(
            optimizer=optimizer,
            end_lr=end_lr,
            num_training_steps=num_training_steps,
        )
    elif id == "cosine_with_restarts":
        if cycles is None:
            cycles = math.ceil(
                math.sqrt(
                    (num_training_steps - num_warmup_steps)
                    / num_training_steps_per_epoch
                )
            )

        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=cycles,
        )
    elif id == "adafactor":
        lr_scheduler = get_adafactor_schedule(optimizer, initial_lr=min_lr)
    else:
        lr_scheduler = get_scheduler_(
            id,
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )

    return lr_scheduler
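

# Minimal usage sketch, runnable as a script. This is only a demonstration of how
# the factory above might be called: the dummy parameter, the "one_cycle" id, and
# the step/epoch counts are arbitrary example values, not defaults used by any trainer.
if __name__ == "__main__":
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=1e-3)

    scheduler = get_scheduler(
        "one_cycle",
        optimizer=optimizer,
        num_training_steps_per_epoch=10,
        train_epochs=5,
        warmup_epochs=1,
    )

    for step in range(5 * 10):
        optimizer.step()
        scheduler.step()
        if step % 10 == 0:
            print(f"step {step:3d}: lr = {scheduler.get_last_lr()[0]:.6f}")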