From c24a24c77a825e71de2e67c1515c84d2b77701fa Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 2 Jan 2023 11:27:38 +0100
Subject: Improved one cycle scheduler

---
 train_ti.py              |  2 +-
 training/optimization.py | 92 ++++++++++++++++++++++++++++++++----------------
 2 files changed, 63 insertions(+), 31 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index 870bd40..775b918 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -250,7 +250,7 @@ def parse_args():
     parser.add_argument(
         "--lr_annealing_exp",
         type=int,
-        default=2,
+        default=1,
         help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
     )
     parser.add_argument(
diff --git a/training/optimization.py b/training/optimization.py
index 725599b..cb7f088 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -1,5 +1,6 @@
 import math
-from typing import Literal
+from typing import NamedTuple, Literal, Callable
+from functools import partial
 
 import torch
 from torch.optim.lr_scheduler import LambdaLR
@@ -9,6 +10,40 @@ from diffusers.utils import logging
 logger = logging.get_logger(__name__)
 
 
+class OneCyclePhase(NamedTuple):
+    step_min: int
+    step_max: int
+    min: float
+    max: float
+    func: Callable[[float], float]
+
+
+def warmup_linear(progress: float):
+    return progress
+
+
+def warmup_cos(exp: int, progress: float):
+    lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
+    lr = lr ** (exp - (exp - 1) * progress)
+    return lr
+
+
+def anneal_linear(progress: float):
+    return 1 - progress
+
+
+def anneal_half_cos(exp: int, progress: float):
+    lr = 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
+    lr = lr ** (exp - (exp - 1) * progress)
+    return lr
+
+
+def anneal_cos(exp: int, progress: float):
+    lr = 0.5 * (1.0 + math.cos(math.pi * progress))
+    lr = lr ** (exp - (exp - 1) * progress)
+    return lr
+
+
 def get_one_cycle_schedule(
     optimizer: torch.optim.Optimizer,
     num_training_steps: int,
@@ -20,38 +55,35 @@ def get_one_cycle_schedule(
     mid_point: int = 0.3,
     last_epoch: int = -1
 ):
-    def lr_lambda(current_step: int):
-        thresh_up = int(num_training_steps * min(mid_point, 0.5))
-
-        if current_step < thresh_up:
-            progress = float(current_step) / float(max(1, thresh_up))
-
-            if warmup == "linear":
-                return min_lr + progress * (1 - min_lr)
+    if warmup == "linear":
+        warmup_func = warmup_linear
+    else:
+        warmup_func = partial(warmup_cos, warmup_exp)
 
-            lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
-            lr = lr ** (warmup_exp - (warmup_exp - 1) * progress)
-            return min_lr + lr * (1 - min_lr)
+    if annealing == "linear":
+        anneal_func = anneal_linear
+    elif annealing == "half_cos":
+        anneal_func = partial(anneal_half_cos, annealing_exp)
+    else:
+        anneal_func = partial(anneal_cos, annealing_exp)
 
-        if annealing == "linear":
-            thresh_down = thresh_up * 2
+    thresh_up = int(num_training_steps * min(mid_point, 0.5))
 
-            if current_step < thresh_down:
-                progress = float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up))
-                return min_lr + progress * (1 - min_lr)
+    if annealing == "linear":
+        thresh_down = thresh_up * 2
+        phases = [
+            OneCyclePhase(0, thresh_up, min_lr, 1, warmup_func),
+            OneCyclePhase(thresh_up, thresh_down, min_lr, 1, anneal_func),
+            OneCyclePhase(thresh_down, num_training_steps, 0, min_lr, anneal_func),
+        ]
+    else:
+        phases = [
+            OneCyclePhase(0, thresh_up, 0, 1, warmup_func),
+            OneCyclePhase(thresh_up, num_training_steps, 0, 1, anneal_func),
+        ]
 
-            progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down))
-            return progress * min_lr
-
-        progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))
-
-        if annealing == "half_cos":
-            lr = 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
-            lr = lr ** (annealing_exp - (annealing_exp - 1) * progress)
-            return lr
-
-        lr = 0.5 * (1.0 + math.cos(math.pi * progress))
-        lr = lr ** (annealing_exp - (annealing_exp - 1) * progress)
-        return lr
+    def lr_lambda(current_step: int):
+        phase = [p for p in phases if current_step >= p.step_min][-1]
+        return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min)
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
-- 
cgit v1.2.3-70-g09d2
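
Not part of the patch: a minimal usage sketch of the reworked scheduler, assuming the repository's training/optimization.py is importable and using illustrative hyperparameter values. The keyword names follow the identifiers visible in the hunks above; the model, optimizer, and step count are made up for the example.

    import torch
    from training.optimization import get_one_cycle_schedule

    # Toy model and optimizer purely for illustration; the base LR set here is
    # what the LambdaLR multiplier (built from the OneCyclePhase list) scales.
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    scheduler = get_one_cycle_schedule(
        optimizer,
        num_training_steps=1000,
        warmup="cos",        # or "linear" for the warmup_linear phase
        annealing="cos",     # "linear" adds a third phase decaying from min_lr to 0
        warmup_exp=1,
        annealing_exp=1,     # matches the new --lr_annealing_exp default in train_ti.py
        min_lr=0.04,         # illustrative; only shapes the annealing="linear" phases
        mid_point=0.3,       # warmup covers the first 30% of the steps (capped at 50%)
    )

    for step in range(1000):
        # ... forward/backward ...
        optimizer.step()
        scheduler.step()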