From 036e33fbde6bad7c48bb6f6b3d695b7908535c64 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 3 Nov 2022 17:56:08 +0100
Subject: Update

---
 training/optimization.py | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+)
 create mode 100644 training/optimization.py

diff --git a/training/optimization.py b/training/optimization.py
new file mode 100644
index 0000000..012beed
--- /dev/null
+++ b/training/optimization.py
@@ -0,0 +1,54 @@
+import math
+from torch.optim.lr_scheduler import LambdaLR
+
+from diffusers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.05, mid_point=0.42, last_epoch=-1):
+    """
+    Create a one-cycle schedule: the learning rate increases linearly from `min_lr` times the initial lr set in the
+    optimizer up to that initial lr, then anneals back down towards 0.
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_training_steps (`int`):
+            The total number of training steps.
+        annealing (`str`, *optional*, defaults to `"cos"`):
+            The annealing strategy after the warmup peak: `"cos"` follows a cosine curve down to 0, while `"linear"`
+            first decreases linearly back to `min_lr` over as many steps as the warmup took, then decays linearly
+            to 0 over the remaining steps.
+        min_lr (`float`, *optional*, defaults to 0.05):
+            The starting learning rate, as a fraction of the initial lr set in the optimizer.
+        mid_point (`float`, *optional*, defaults to 0.42):
+            The fraction of the training steps at which the warmup peaks, capped at 0.5.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step: int):
+        thresh_up = int(num_training_steps * min(mid_point, 0.5))
+
+        # Warmup: scale the lr linearly from min_lr up to 1 over the first thresh_up steps.
+        if current_step < thresh_up:
+            return min_lr + float(current_step) / float(max(1, thresh_up)) * (1 - min_lr)
+
+        if annealing == "linear":
+            thresh_down = thresh_up * 2
+
+            # Decrease linearly back to min_lr over as many steps as the warmup took...
+            if current_step < thresh_down:
+                return min_lr + float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up)) * (1 - min_lr)
+
+            # ...then decay linearly from min_lr to 0 over the remaining steps.
+            return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down))) * min_lr
+        else:
+            # Cosine annealing: decrease from 1 to 0 over the remaining steps.
+            progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))
+
+            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
-- 
cgit v1.2.3-70-g09d2
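
A quick way to sanity-check the schedule shape is to drive it with a throwaway optimizer and print the learning rate at a few steps. The sketch below is illustrative only: it assumes the patch above has been applied and that the repository root is on PYTHONPATH so the module imports as `training.optimization`; the model, base lr, and step count are placeholders.

import torch

from training.optimization import get_one_cycle_schedule

# Placeholder model and optimizer; only the learning rate matters here.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
num_training_steps = 100

scheduler = get_one_cycle_schedule(optimizer, num_training_steps)

for step in range(num_training_steps):
    # A real training loop would run the forward/backward pass before these calls.
    optimizer.step()
    scheduler.step()
    if step % 10 == 0:
        print(f"step {step:3d}: lr = {scheduler.get_last_lr()[0]:.2e}")

With the defaults (annealing="cos", min_lr=0.05, mid_point=0.42) and a base lr of 1e-4, the printed values rise from 5e-06 towards 1e-04 over the first 42 steps, then fall along a cosine curve towards 0.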