From 8364ce697ddf6117fdd4f7222832d546d63880de Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 21 Jun 2023 13:28:49 +0200
Subject: Update

---
 training/optimization.py | 38 ++++++++++++++++++++++++++++----------
 1 file changed, 28 insertions(+), 10 deletions(-)

(limited to 'training/optimization.py')

diff --git a/training/optimization.py b/training/optimization.py
index d22a900..55531bf 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -5,7 +5,10 @@ from functools import partial
 import torch
 from torch.optim.lr_scheduler import LambdaLR
 
-from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
+from diffusers.optimization import (
+    get_scheduler as get_scheduler_,
+    get_cosine_with_hard_restarts_schedule_with_warmup,
+)
 from transformers.optimization import get_adafactor_schedule
 
 
@@ -52,7 +55,7 @@ def get_one_cycle_schedule(
     annealing_exp: int = 1,
     min_lr: float = 0.04,
     mid_point: float = 0.3,
-    last_epoch: int = -1
+    last_epoch: int = -1,
 ):
     if warmup == "linear":
         warmup_func = warmup_linear
@@ -83,12 +86,16 @@ def get_one_cycle_schedule(
 
     def lr_lambda(current_step: int):
         phase = [p for p in phases if current_step >= p.step_min][-1]
-        return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min)
+        return phase.min + phase.func(
+            (current_step - phase.step_min) / (phase.step_max - phase.step_min)
+        ) * (phase.max - phase.min)
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
 
 
-def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1):
+def get_exponential_growing_schedule(
+    optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1
+):
     def lr_lambda(base_lr: float, current_step: int):
         return (end_lr / base_lr) ** (current_step / num_training_steps)
 
@@ -132,7 +139,14 @@ def get_scheduler(
         )
     elif id == "exponential_growth":
         if cycles is None:
-            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+            cycles = math.ceil(
+                math.sqrt(
+                    (
+                        (num_training_steps - num_warmup_steps)
+                        / num_training_steps_per_epoch
+                    )
+                )
+            )
 
         lr_scheduler = get_exponential_growing_schedule(
             optimizer=optimizer,
@@ -141,7 +155,14 @@ def get_scheduler(
         )
     elif id == "cosine_with_restarts":
         if cycles is None:
-            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+            cycles = math.ceil(
+                math.sqrt(
+                    (
+                        (num_training_steps - num_warmup_steps)
+                        / num_training_steps_per_epoch
+                    )
+                )
+            )
 
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
@@ -150,10 +171,7 @@ def get_scheduler(
             num_cycles=cycles,
         )
     elif id == "adafactor":
-        lr_scheduler = get_adafactor_schedule(
-            optimizer,
-            initial_lr=min_lr
-        )
+        lr_scheduler = get_adafactor_schedule(optimizer, initial_lr=min_lr)
     else:
         lr_scheduler = get_scheduler_(
             id,
-- 
cgit v1.2.3-54-g00ecf
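
Note on the schedules touched by this reformatting: get_exponential_growing_schedule builds a LambdaLR whose multiplier grows geometrically, so the effective learning rate moves from the optimizer's base LR toward end_lr as base_lr * (end_lr / base_lr) ** (step / num_training_steps). The following is a minimal, self-contained sketch of that behaviour only; the model, optimizer, and the values of base_lr, end_lr, and num_training_steps are illustrative assumptions, not taken from the patch.

    import torch
    from torch.optim.lr_scheduler import LambdaLR

    # Illustrative values only; not part of the patch above.
    model = torch.nn.Linear(4, 4)
    base_lr, end_lr, num_training_steps = 1e-6, 1e-2, 1_000

    optimizer = torch.optim.SGD(model.parameters(), lr=base_lr)

    def lr_lambda(current_step: int) -> float:
        # The multiplier grows geometrically from 1 toward end_lr / base_lr over
        # num_training_steps steps, mirroring the lr_lambda in
        # get_exponential_growing_schedule.
        return (end_lr / base_lr) ** (current_step / num_training_steps)

    scheduler = LambdaLR(optimizer, lr_lambda)

    for _ in range(num_training_steps):
        optimizer.step()
        scheduler.step()

The reformatted cycle-count fallback in the "exponential_growth" and "cosine_with_restarts" branches is unchanged in behaviour: when cycles is not given, it defaults to ceil(sqrt((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)).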