import math from typing import Literal import torch from torch.optim.lr_scheduler import LambdaLR from diffusers.utils import logging logger = logging.get_logger(__name__) def get_one_cycle_schedule( optimizer: torch.optim.Optimizer, num_training_steps: int, warmup: Literal["cos", "linear"] = "cos", annealing: Literal["cos", "half_cos", "linear"] = "cos", min_lr: int = 0.04, mid_point: int = 0.3, last_epoch: int = -1 ): def lr_lambda(current_step: int): thresh_up = int(num_training_steps * min(mid_point, 0.5)) if current_step < thresh_up: progress = float(current_step) / float(max(1, thresh_up)) if warmup == "linear": return min_lr + progress * (1 - min_lr) return min_lr + 0.5 * (1.0 + math.cos(math.pi * (1 + progress))) if annealing == "linear": thresh_down = thresh_up * 2 if current_step < thresh_down: progress = float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up)) return min_lr + progress * (1 - min_lr) progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down)) return progress * min_lr progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up)) if annealing == "half_cos": return 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress)) return 0.5 * (1.0 + math.cos(math.pi * progress)) return LambdaLR(optimizer, lr_lambda, last_epoch)