diff options
Diffstat (limited to 'training/optimization.py')
-rw-r--r-- | training/optimization.py | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/training/optimization.py b/training/optimization.py index 6dee4bc..6c9a35d 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
@@ -87,6 +87,15 @@ def get_one_cycle_schedule( | |||
87 | return LambdaLR(optimizer, lr_lambda, last_epoch) | 87 | return LambdaLR(optimizer, lr_lambda, last_epoch) |
88 | 88 | ||
89 | 89 | ||
def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1):
    """Build a ``LambdaLR`` whose learning rate grows geometrically toward ``end_lr``.

    For a parameter group whose initial learning rate is ``base_lr``, the
    multiplier at step ``t`` is ``(end_lr / base_lr) ** (t / num_training_steps)``,
    so the effective rate sweeps smoothly from ``base_lr`` at step 0 up to
    ``end_lr`` at step ``num_training_steps`` (useful for LR range tests).

    Args:
        optimizer: wrapped optimizer; each param group's current ``lr`` is
            taken as that group's starting rate.
        end_lr: learning rate reached at the final training step.
        num_training_steps: total number of scheduler steps over the sweep.
        last_epoch: index of the last epoch, forwarded to ``LambdaLR``.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` with one multiplier per group.
    """
    multipliers = []
    for group in optimizer.param_groups:
        start_lr = group["lr"]

        # Bind this group's starting lr via a default argument so every
        # group gets its own growth curve (avoids late-binding closure bugs).
        def scale(step: int, start_lr: float = start_lr) -> float:
            return (end_lr / start_lr) ** (step / num_training_steps)

        multipliers.append(scale)

    return LambdaLR(optimizer, multipliers, last_epoch)
97 | |||
98 | |||
90 | def get_scheduler( | 99 | def get_scheduler( |
91 | id: str, | 100 | id: str, |
92 | optimizer: torch.optim.Optimizer, | 101 | optimizer: torch.optim.Optimizer, |
@@ -97,6 +106,7 @@ def get_scheduler( | |||
97 | annealing_func: Literal["cos", "half_cos", "linear"] = "cos", | 106 | annealing_func: Literal["cos", "half_cos", "linear"] = "cos", |
98 | warmup_exp: int = 1, | 107 | warmup_exp: int = 1, |
99 | annealing_exp: int = 1, | 108 | annealing_exp: int = 1, |
109 | end_lr: float = 1e3, | ||
100 | cycles: int = 1, | 110 | cycles: int = 1, |
101 | train_epochs: int = 100, | 111 | train_epochs: int = 100, |
102 | warmup_epochs: int = 10, | 112 | warmup_epochs: int = 10, |
@@ -117,6 +127,15 @@ def get_scheduler( | |||
117 | annealing_exp=annealing_exp, | 127 | annealing_exp=annealing_exp, |
118 | min_lr=min_lr, | 128 | min_lr=min_lr, |
119 | ) | 129 | ) |
130 | elif id == "exponential_growth": | ||
131 | if cycles is None: | ||
132 | cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) | ||
133 | |||
134 | lr_scheduler = get_exponential_growing_schedule( | ||
135 | optimizer=optimizer, | ||
136 | end_lr=end_lr, | ||
137 | num_training_steps=num_training_steps, | ||
138 | ) | ||
120 | elif id == "cosine_with_restarts": | 139 | elif id == "cosine_with_restarts": |
121 | if cycles is None: | 140 | if cycles is None: |
122 | cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) | 141 | cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) |