From f87d9fdf541b0282249ddde1dc0302317350f998 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 29 Dec 2022 15:28:02 +0100
Subject: Update

---
 training/optimization.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

(limited to 'training/optimization.py')

diff --git a/training/optimization.py b/training/optimization.py
index dfee2b5..3340544 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -14,6 +14,8 @@ def get_one_cycle_schedule(
     num_training_steps: int,
     warmup: Literal["cos", "linear"] = "cos",
     annealing: Literal["cos", "half_cos", "linear"] = "cos",
+    warmup_exp: int = 1,
+    annealing_exp: int = 2,
     min_lr: int = 0.04,
     mid_point: int = 0.3,
     last_epoch: int = -1
@@ -27,7 +29,9 @@ def get_one_cycle_schedule(
             if warmup == "linear":
                 return min_lr + progress * (1 - min_lr)
 
-            return min_lr + 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
+            lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
+            lr = lr ** warmup_exp
+            return min_lr + lr * (1 - min_lr)
 
         if annealing == "linear":
             thresh_down = thresh_up * 2
@@ -42,8 +46,12 @@ def get_one_cycle_schedule(
         progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))
 
         if annealing == "half_cos":
-            return 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
+            lr = 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
+            lr = lr ** annealing_exp
+            return lr
 
-        return 0.5 * (1.0 + math.cos(math.pi * progress))
+        lr = 0.5 * (1.0 + math.cos(math.pi * progress))
+        lr = lr ** annealing_exp
+        return lr
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
-- 
cgit v1.2.3-54-g00ecf