import math
from functools import partial
from typing import Callable, Literal, NamedTuple

import torch
from torch.optim.lr_scheduler import LambdaLR


class OneCyclePhase(NamedTuple):
    # One segment of the schedule: steps in [step_min, step_max) are mapped to
    # an LR factor in [min, max] via func(progress), with progress in [0, 1].
    step_min: int
    step_max: int
    min: float
    max: float
    func: Callable[[float], float]


def warmup_linear(progress: float):
    return progress


def warmup_cos(exp: int, progress: float):
    # Cosine ramp from 0 to 1; exp > 1 keeps the ramp flatter early on.
    lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
    lr = lr ** (exp - (exp - 1) * progress)
    return lr


def anneal_linear(progress: float):
    return 1 - progress


def anneal_half_cos(exp: int, progress: float):
    # Half-cosine decay from 1 to 0.
    lr = 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
    lr = lr ** (exp - (exp - 1) * progress)
    return lr


def anneal_cos(exp: int, progress: float):
    # Full-cosine decay from 1 to 0.
    lr = 0.5 * (1.0 + math.cos(math.pi * progress))
    lr = lr ** (exp - (exp - 1) * progress)
    return lr


def get_one_cycle_schedule(
    optimizer: torch.optim.Optimizer,
    num_training_steps: int,
    warmup: Literal["cos", "linear"] = "cos",
    annealing: Literal["cos", "half_cos", "linear"] = "cos",
    warmup_exp: int = 1,
    annealing_exp: int = 1,
    min_lr: float = 0.04,
    mid_point: float = 0.3,
    last_epoch: int = -1,
):
    if warmup == "linear":
        warmup_func = warmup_linear
    else:
        warmup_func = partial(warmup_cos, warmup_exp)

    if annealing == "linear":
        anneal_func = anneal_linear
    elif annealing == "half_cos":
        anneal_func = partial(anneal_half_cos, annealing_exp)
    else:
        anneal_func = partial(anneal_cos, annealing_exp)

    # Warmup ends after mid_point of training, capped at 50% of the steps.
    thresh_up = int(num_training_steps * min(mid_point, 0.5))

    if annealing == "linear":
        # Linear annealing uses an extra phase: decay to min_lr over the same
        # number of steps as the warmup, then decay from min_lr to 0.
        thresh_down = thresh_up * 2
        phases = [
            OneCyclePhase(0, thresh_up, min_lr, 1, warmup_func),
            OneCyclePhase(thresh_up, thresh_down, min_lr, 1, anneal_func),
            OneCyclePhase(thresh_down, num_training_steps, 0, min_lr, anneal_func),
        ]
    else:
        phases = [
            OneCyclePhase(0, thresh_up, min_lr, 1, warmup_func),
            OneCyclePhase(thresh_up, num_training_steps, 0, 1, anneal_func),
        ]

    def lr_lambda(current_step: int):
        # Pick the latest phase that has started and interpolate within it.
        phase = [p for p in phases if current_step >= p.step_min][-1]
        progress = (current_step - phase.step_min) / (phase.step_max - phase.step_min)
        return phase.min + phase.func(progress) * (phase.max - phase.min)

    return LambdaLR(optimizer, lr_lambda, last_epoch)
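

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): drives the
# schedule with a tiny placeholder model and a hypothetical step count. LambdaLR
# multiplies the optimizer's base lr by the factor returned by lr_lambda, so the
# effective lr peaks at the base lr and bottoms out near 0.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = get_one_cycle_schedule(
        optimizer,
        num_training_steps=1_000,
        warmup="cos",
        annealing="cos",
        min_lr=0.04,
        mid_point=0.3,
    )
    for step in range(1_000):
        # ... forward/backward pass would go here ...
        optimizer.step()
        scheduler.step()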