diff options
Diffstat (limited to 'training/optimization.py')
-rw-r--r-- | training/optimization.py | 43 |
1 file changed, 23 insertions, 20 deletions
diff --git a/training/optimization.py b/training/optimization.py index a0c8673..dfee2b5 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
@@ -1,4 +1,7 @@ | |||
1 | import math | 1 | import math |
2 | from typing import Literal | ||
3 | |||
4 | import torch | ||
2 | from torch.optim.lr_scheduler import LambdaLR | 5 | from torch.optim.lr_scheduler import LambdaLR |
3 | 6 | ||
4 | from diffusers.utils import logging | 7 | from diffusers.utils import logging |
@@ -6,41 +9,41 @@ from diffusers.utils import logging | |||
6 | logger = logging.get_logger(__name__) | 9 | logger = logging.get_logger(__name__) |
7 | 10 | ||
8 | 11 | ||
def get_one_cycle_schedule(
    optimizer: torch.optim.Optimizer,
    num_training_steps: int,
    warmup: Literal["cos", "linear"] = "cos",
    annealing: Literal["cos", "half_cos", "linear"] = "cos",
    min_lr: float = 0.04,
    mid_point: float = 0.3,
    last_epoch: int = -1,
):
    """
    Create a one-cycle schedule: the lr multiplier ramps from ``min_lr`` up to 1.0
    over the warmup phase, then anneals back down towards 0 for the rest of training.

    Args:
        optimizer (`torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_training_steps (`int`):
            The total number of training steps.
        warmup (`"cos"` or `"linear"`, *optional*, defaults to `"cos"`):
            Shape of the warmup ramp.
        annealing (`"cos"`, `"half_cos"` or `"linear"`, *optional*, defaults to `"cos"`):
            Shape of the annealing phase.
        min_lr (`float`, *optional*, defaults to 0.04):
            Lr multiplier at step 0 (fraction of the optimizer's base lr).
        mid_point (`float`, *optional*, defaults to 0.3):
            Fraction of training at which the peak lr is reached (capped at 0.5).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        # Step at which warmup ends and annealing begins.
        thresh_up = int(num_training_steps * min(mid_point, 0.5))

        if current_step < thresh_up:
            progress = float(current_step) / float(max(1, thresh_up))

            if warmup == "linear":
                return min_lr + progress * (1 - min_lr)

            # Cosine ramp from min_lr to 1.0. The (1 - min_lr) scale keeps the
            # peak at exactly 1.0, matching the linear branch; without it the
            # multiplier would overshoot to 1 + min_lr at the end of warmup.
            return min_lr + 0.5 * (1.0 + math.cos(math.pi * (1 + progress))) * (1 - min_lr)

        if annealing == "linear":
            thresh_down = thresh_up * 2

            if current_step < thresh_down:
                # Linear decay from 1.0 back down to min_lr.
                progress = float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up))
                return min_lr + progress * (1 - min_lr)

            # Final linear taper from min_lr to 0; clamp so steps past
            # num_training_steps never produce a negative lr.
            progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down))
            return max(0.0, progress) * min_lr

        progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))

        if annealing == "half_cos":
            # Quarter-cosine decay (1 -> 0); clamped for steps past the end.
            return max(0.0, 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress)))

        # Half-cosine decay (1 -> 0); clamped for steps past the end.
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)