-rw-r--r--  train_ti.py               |  2
-rw-r--r--  training/optimization.py  | 92
2 files changed, 63 insertions, 31 deletions
diff --git a/train_ti.py b/train_ti.py
index 870bd40..775b918 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -250,7 +250,7 @@ def parse_args():
     parser.add_argument(
         "--lr_annealing_exp",
         type=int,
-        default=2,
+        default=1,
         help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
     )
     parser.add_argument(
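
The lowered default matters because the exponent feeds the lr ** (exp - (exp - 1) * progress) term in the cosine warmup/annealing helpers introduced below: with exp = 1 that term is always 1, leaving a plain cosine, while larger exponents make the curve fall off faster through the phase. A minimal standalone sketch of the effect, mirroring the anneal_cos formula from the second file (the printed values are illustrative only):

    import math

    def anneal_cos(exp: int, progress: float) -> float:
        # Same shaping as in training/optimization.py: a cosine falling from 1 to 0,
        # raised to an exponent that interpolates from `exp` down to 1.
        lr = 0.5 * (1.0 + math.cos(math.pi * progress))
        return lr ** (exp - (exp - 1) * progress)

    # exp=1 (new default) keeps the plain half-cosine; exp=2 (old default) decays faster.
    for progress in (0.0, 0.25, 0.5, 0.75, 1.0):
        print(f"{progress:.2f}  exp=1: {anneal_cos(1, progress):.4f}  exp=2: {anneal_cos(2, progress):.4f}")
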
diff --git a/training/optimization.py b/training/optimization.py
index 725599b..cb7f088 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -1,5 +1,6 @@
 import math
-from typing import Literal
+from typing import NamedTuple, Literal, Callable
+from functools import partial
 
 import torch
 from torch.optim.lr_scheduler import LambdaLR
@@ -9,6 +10,40 @@ from diffusers.utils import logging
 logger = logging.get_logger(__name__)
 
 
+class OneCyclePhase(NamedTuple):
+    step_min: int
+    step_max: int
+    min: float
+    max: float
+    func: Callable[[float], float]
+
+
+def warmup_linear(progress: float):
+    return progress
+
+
+def warmup_cos(exp: int, progress: float):
+    lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
+    lr = lr ** (exp - (exp - 1) * progress)
+    return lr
+
+
+def anneal_linear(progress: float):
+    return 1 - progress
+
+
+def anneal_half_cos(exp: int, progress: float):
+    lr = 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
+    lr = lr ** (exp - (exp - 1) * progress)
+    return lr
+
+
+def anneal_cos(exp: int, progress: float):
+    lr = 0.5 * (1.0 + math.cos(math.pi * progress))
+    lr = lr ** (exp - (exp - 1) * progress)
+    return lr
+
+
 def get_one_cycle_schedule(
     optimizer: torch.optim.Optimizer,
     num_training_steps: int,
@@ -20,38 +55,35 @@ def get_one_cycle_schedule(
     mid_point: int = 0.3,
     last_epoch: int = -1
 ):
-    def lr_lambda(current_step: int):
-        thresh_up = int(num_training_steps * min(mid_point, 0.5))
-
-        if current_step < thresh_up:
-            progress = float(current_step) / float(max(1, thresh_up))
-
-            if warmup == "linear":
-                return min_lr + progress * (1 - min_lr)
-
-            lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
-            lr = lr ** (warmup_exp - (warmup_exp - 1) * progress)
-            return min_lr + lr * (1 - min_lr)
-
-        if annealing == "linear":
-            thresh_down = thresh_up * 2
-
-            if current_step < thresh_down:
-                progress = float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up))
-                return min_lr + progress * (1 - min_lr)
-
-            progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down))
-            return progress * min_lr
-
-        progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))
-
-        if annealing == "half_cos":
-            lr = 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
-            lr = lr ** (annealing_exp - (annealing_exp - 1) * progress)
-            return lr
-
-        lr = 0.5 * (1.0 + math.cos(math.pi * progress))
-        lr = lr ** (annealing_exp - (annealing_exp - 1) * progress)
-        return lr
+    if warmup == "linear":
+        warmup_func = warmup_linear
+    else:
+        warmup_func = partial(warmup_cos, warmup_exp)
+
+    if annealing == "linear":
+        anneal_func = anneal_linear
+    elif annealing == "half_cos":
+        anneal_func = partial(anneal_half_cos, annealing_exp)
+    else:
+        anneal_func = partial(anneal_cos, annealing_exp)
+
+    thresh_up = int(num_training_steps * min(mid_point, 0.5))
+
+    if annealing == "linear":
+        thresh_down = thresh_up * 2
+        phases = [
+            OneCyclePhase(0, thresh_up, min_lr, 1, warmup_func),
+            OneCyclePhase(thresh_up, thresh_down, min_lr, 1, anneal_func),
+            OneCyclePhase(thresh_down, num_training_steps, 0, min_lr, anneal_func),
+        ]
+    else:
+        phases = [
+            OneCyclePhase(0, thresh_up, 0, 1, warmup_func),
+            OneCyclePhase(thresh_up, num_training_steps, 0, 1, anneal_func),
+        ]
+
+    def lr_lambda(current_step: int):
+        phase = [p for p in phases if current_step >= p.step_min][-1]
+        return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min)
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
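
For context, a minimal sketch of how the refactored schedule might be driven. The keyword names below follow the identifiers used inside get_one_cycle_schedule; the full signature is cut off by the hunk, so treat them and the concrete values as assumptions rather than the training script's actual wiring:

    import torch

    from training.optimization import get_one_cycle_schedule

    # Dummy parameter/optimizer just to exercise the scheduler.
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.AdamW([param], lr=1e-4)

    scheduler = get_one_cycle_schedule(
        optimizer,
        num_training_steps=1000,
        warmup="cos",          # anything other than "linear" selects warmup_cos
        annealing="cos",       # "linear", "half_cos", or anything else for anneal_cos
        warmup_exp=1,
        annealing_exp=1,
        min_lr=0.04,           # only consulted by the linear-annealing phase list
    )

    for step in range(1000):
        optimizer.step()
        scheduler.step()  # lr_lambda picks the active OneCyclePhase for this step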