diff options
Diffstat (limited to 'training/optimization.py')
-rw-r--r-- | training/optimization.py | 10 |
1 file changed, 5 insertions, 5 deletions
diff --git a/training/optimization.py b/training/optimization.py index 5db7794..6dee4bc 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
@@ -49,8 +49,8 @@ def get_one_cycle_schedule( | |||
49 | annealing: Literal["cos", "half_cos", "linear"] = "cos", | 49 | annealing: Literal["cos", "half_cos", "linear"] = "cos", |
50 | warmup_exp: int = 1, | 50 | warmup_exp: int = 1, |
51 | annealing_exp: int = 1, | 51 | annealing_exp: int = 1, |
52 | min_lr: int = 0.04, | 52 | min_lr: float = 0.04, |
53 | mid_point: int = 0.3, | 53 | mid_point: float = 0.3, |
54 | last_epoch: int = -1 | 54 | last_epoch: int = -1 |
55 | ): | 55 | ): |
56 | if warmup == "linear": | 56 | if warmup == "linear": |
@@ -91,10 +91,10 @@ def get_scheduler( | |||
91 | id: str, | 91 | id: str, |
92 | optimizer: torch.optim.Optimizer, | 92 | optimizer: torch.optim.Optimizer, |
93 | num_training_steps_per_epoch: int, | 93 | num_training_steps_per_epoch: int, |
94 | gradient_accumulation_steps: int, | 94 | gradient_accumulation_steps: int = 1, |
95 | min_lr: float = 0.04, | 95 | min_lr: float = 0.04, |
96 | warmup_func: str = "cos", | 96 | warmup_func: Literal["cos", "linear"] = "cos", |
97 | annealing_func: str = "cos", | 97 | annealing_func: Literal["cos", "half_cos", "linear"] = "cos", |
98 | warmup_exp: int = 1, | 98 | warmup_exp: int = 1, |
99 | annealing_exp: int = 1, | 99 | annealing_exp: int = 1, |
100 | cycles: int = 1, | 100 | cycles: int = 1, |