import math

from torch.optim.lr_scheduler import LambdaLR

from diffusers.utils import logging

logger = logging.get_logger(__name__)


def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.05, mid_point=0.4, last_epoch=-1):
    """
    Create a one-cycle schedule: the learning rate increases linearly from `min_lr` times the initial lr set in the
    optimizer up to that initial lr during warmup, then anneals back down to 0, either along a cosine curve or
    linearly (first back to `min_lr`, then to 0).

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_training_steps (`int`):
            The total number of training steps.
        annealing (`str`, *optional*, defaults to `"cos"`):
            The annealing mode after warmup, either `"linear"` or cosine (any other value).
        min_lr (`float`, *optional*, defaults to 0.05):
            The starting learning rate, as a fraction of the initial lr set in the optimizer.
        mid_point (`float`, *optional*, defaults to 0.4):
            The fraction of `num_training_steps` at which warmup ends, capped at 0.5.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        # Warmup phase: ramp linearly from min_lr up to the full learning rate.
        thresh_up = int(num_training_steps * min(mid_point, 0.5))
        if current_step < thresh_up:
            return min_lr + float(current_step) / float(max(1, thresh_up)) * (1 - min_lr)

        if annealing == "linear":
            # Linear annealing: decay from 1.0 back to min_lr over a span equal to the warmup,
            # then decay from min_lr to 0 over the remaining steps.
            thresh_down = thresh_up * 2
            if current_step < thresh_down:
                return min_lr + float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up)) * (1 - min_lr)
            progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down))
            return max(0.0, progress) * min_lr
        else:
            # Cosine annealing: follow the cosine from pi/2 to pi, decaying from 1.0 to 0.
            progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))
            return max(0.0, 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
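

# A minimal usage sketch (not part of the module's public API): how this scheduler might be
# wired into a training loop. The model, learning rate, and step count below are placeholder
# assumptions chosen purely for illustration.
if __name__ == "__main__":
    import torch

    model = torch.nn.Linear(4, 4)  # placeholder model, stands in for the real network
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = get_one_cycle_schedule(optimizer, num_training_steps=1000, annealing="cos")

    for step in range(1000):
        # forward/backward would go here in a real loop
        optimizer.step()
        scheduler.step()
        if step % 250 == 0:
            print(f"step {step}: lr = {scheduler.get_last_lr()[0]:.6e}")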