diff options
Diffstat (limited to 'training/optimization.py')
-rw-r--r-- | training/optimization.py | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/training/optimization.py b/training/optimization.py index dfee2b5..3340544 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
@@ -14,6 +14,8 @@ def get_one_cycle_schedule( | |||
14 | num_training_steps: int, | 14 | num_training_steps: int, |
15 | warmup: Literal["cos", "linear"] = "cos", | 15 | warmup: Literal["cos", "linear"] = "cos", |
16 | annealing: Literal["cos", "half_cos", "linear"] = "cos", | 16 | annealing: Literal["cos", "half_cos", "linear"] = "cos", |
17 | warmup_exp: int = 1, | ||
18 | annealing_exp: int = 2, | ||
17 | min_lr: int = 0.04, | 19 | min_lr: int = 0.04, |
18 | mid_point: int = 0.3, | 20 | mid_point: int = 0.3, |
19 | last_epoch: int = -1 | 21 | last_epoch: int = -1 |
@@ -27,7 +29,9 @@ def get_one_cycle_schedule( | |||
27 | if warmup == "linear": | 29 | if warmup == "linear": |
28 | return min_lr + progress * (1 - min_lr) | 30 | return min_lr + progress * (1 - min_lr) |
29 | 31 | ||
30 | return min_lr + 0.5 * (1.0 + math.cos(math.pi * (1 + progress))) | 32 | lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress))) |
33 | lr = lr ** warmup_exp | ||
34 | return min_lr + lr * (1 - min_lr) | ||
31 | 35 | ||
32 | if annealing == "linear": | 36 | if annealing == "linear": |
33 | thresh_down = thresh_up * 2 | 37 | thresh_down = thresh_up * 2 |
@@ -42,8 +46,12 @@ def get_one_cycle_schedule( | |||
42 | progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up)) | 46 | progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up)) |
43 | 47 | ||
44 | if annealing == "half_cos": | 48 | if annealing == "half_cos": |
45 | return 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress)) | 49 | lr = 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress)) |
50 | lr = lr ** annealing_exp | ||
51 | return lr | ||
46 | 52 | ||
47 | return 0.5 * (1.0 + math.cos(math.pi * progress)) | 53 | lr = 0.5 * (1.0 + math.cos(math.pi * progress)) |
54 | lr = lr ** annealing_exp | ||
55 | return lr | ||
48 | 56 | ||
49 | return LambdaLR(optimizer, lr_lambda, last_epoch) | 57 | return LambdaLR(optimizer, lr_lambda, last_epoch) |