From 4d3d318a4168ef79847737cef2c0ad8a4dafd3e7 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 29 Dec 2022 09:00:19 +0100 Subject: Training improvements --- training/optimization.py | 43 +++++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 20 deletions(-) (limited to 'training/optimization.py') diff --git a/training/optimization.py b/training/optimization.py index a0c8673..dfee2b5 100644 --- a/training/optimization.py +++ b/training/optimization.py @@ -1,4 +1,7 @@ import math +from typing import Literal + +import torch from torch.optim.lr_scheduler import LambdaLR from diffusers.utils import logging @@ -6,41 +9,41 @@ from diffusers.utils import logging logger = logging.get_logger(__name__) -def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.04, mid_point=0.3, last_epoch=-1): - """ - Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after - a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. - Args: - optimizer ([`~torch.optim.Optimizer`]): - The optimizer for which to schedule the learning rate. - num_training_steps (`int`): - The total number of training steps. - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - Return: - `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
def get_one_cycle_schedule(
    optimizer: torch.optim.Optimizer,
    num_training_steps: int,
    warmup: Literal["cos", "linear"] = "cos",
    annealing: Literal["cos", "half_cos", "linear"] = "cos",
    min_lr: float = 0.04,
    mid_point: float = 0.3,
    last_epoch: int = -1,
) -> LambdaLR:
    """
    Create a one-cycle schedule: the learning rate multiplier rises from `min_lr` to 1 during a warm-up phase and is
    annealed back down for the remaining steps.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_training_steps (`int`):
            The total number of training steps.
        warmup (`str`, *optional*, defaults to `"cos"`):
            Shape of the warm-up ramp, `"linear"` or `"cos"`. Any other value is treated as `"cos"`.
        annealing (`str`, *optional*, defaults to `"cos"`):
            Shape of the decay. `"linear"` mirrors the warm-up back down to `min_lr` and then decays linearly to 0;
            `"half_cos"` and `"cos"` decay from 1 to 0 along a quarter / half cosine wave. Any other value is
            treated as `"cos"`.
        min_lr (`float`, *optional*, defaults to 0.04):
            Multiplier at step 0, as a fraction of the initial lr set in the optimizer.
        mid_point (`float`, *optional*, defaults to 0.3):
            Fraction of `num_training_steps` spent warming up. Values above 0.5 are capped at 0.5.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # The phase boundaries do not depend on the current step, so compute them once.
    thresh_up = int(num_training_steps * min(mid_point, 0.5))
    thresh_down = thresh_up * 2

    def lr_lambda(current_step: int) -> float:
        if current_step < thresh_up:
            progress = float(current_step) / float(max(1, thresh_up))

            if warmup == "linear":
                return min_lr + progress * (1 - min_lr)

            # Scale the cosine ramp by (1 - min_lr) so it peaks at exactly 1.0,
            # like the linear ramp, instead of overshooting to 1 + min_lr.
            return min_lr + 0.5 * (1.0 + math.cos(math.pi * (1 + progress))) * (1 - min_lr)

        if annealing == "linear":
            if current_step < thresh_down:
                progress = float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up))
                return min_lr + progress * (1 - min_lr)

            progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down))
            # Clamp so stepping past num_training_steps never yields a negative multiplier.
            return max(0.0, progress) * min_lr

        progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))
        # Clamp so the cosine does not start rising again past num_training_steps.
        progress = min(1.0, progress)

        if annealing == "half_cos":
            return 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))

        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return LambdaLR(optimizer, lr_lambda, last_epoch)