 train_ti.py              |  2 +-
 training/optimization.py | 92 ++++++++++++++++++++++----------
 2 files changed, 63 insertions(+), 31 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 870bd40..775b918 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -250,7 +250,7 @@ def parse_args():
     parser.add_argument(
         "--lr_annealing_exp",
         type=int,
-        default=2,
+        default=1,
         help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
     )
     parser.add_argument(
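Note that with the new default of 1, the exponent modifier `exp - (exp - 1) * progress` in the cosine warm-up/annealing helpers added below evaluates to 1 at every step, so the default schedule reduces to a plain cosine; the previous shaping can be restored by passing `--lr_annealing_exp 2` explicitly on the command line.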
diff --git a/training/optimization.py b/training/optimization.py
index 725599b..cb7f088 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -1,5 +1,6 @@
 import math
-from typing import Literal
+from typing import NamedTuple, Literal, Callable
+from functools import partial
 
 import torch
 from torch.optim.lr_scheduler import LambdaLR
@@ -9,6 +10,40 @@ from diffusers.utils import logging
 logger = logging.get_logger(__name__)
 
 
+class OneCyclePhase(NamedTuple):
+    step_min: int
+    step_max: int
+    min: float
+    max: float
+    func: Callable[[float], float]
+
+
+def warmup_linear(progress: float):
+    return progress
+
+
+def warmup_cos(exp: int, progress: float):
+    lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
+    lr = lr ** (exp - (exp - 1) * progress)
+    return lr
+
+
+def anneal_linear(progress: float):
+    return 1 - progress
+
+
+def anneal_half_cos(exp: int, progress: float):
+    lr = 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
+    lr = lr ** (exp - (exp - 1) * progress)
+    return lr
+
+
+def anneal_cos(exp: int, progress: float):
+    lr = 0.5 * (1.0 + math.cos(math.pi * progress))
+    lr = lr ** (exp - (exp - 1) * progress)
+    return lr
+
+
 def get_one_cycle_schedule(
     optimizer: torch.optim.Optimizer,
     num_training_steps: int,
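For intuition, a minimal sketch (not part of the commit) of how the extracted cosine helpers behave: with the new default exponent of 1 the modifier `exp - (exp - 1) * progress` is identically 1, so warmup_cos rises from 0 to 1 and anneal_cos falls from 1 to 0 over a phase. The sampling loop and the chosen progress values are illustrative only.

import math
from functools import partial

# Copies of the helpers added in this commit (training/optimization.py).
def warmup_cos(exp: int, progress: float):
    lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
    lr = lr ** (exp - (exp - 1) * progress)
    return lr

def anneal_cos(exp: int, progress: float):
    lr = 0.5 * (1.0 + math.cos(math.pi * progress))
    lr = lr ** (exp - (exp - 1) * progress)
    return lr

warmup = partial(warmup_cos, 1)   # exp=1: plain cosine ramp 0 -> 1
anneal = partial(anneal_cos, 1)   # exp=1: plain cosine decay 1 -> 0

for p in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"progress={p:.2f}  warmup={warmup(p):.3f}  anneal={anneal(p):.3f}")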
@@ -20,38 +55,35 @@ def get_one_cycle_schedule(
     mid_point: int = 0.3,
     last_epoch: int = -1
 ):
-    def lr_lambda(current_step: int):
-        thresh_up = int(num_training_steps * min(mid_point, 0.5))
-
-        if current_step < thresh_up:
-            progress = float(current_step) / float(max(1, thresh_up))
-
-            if warmup == "linear":
-                return min_lr + progress * (1 - min_lr)
-
-            lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
-            lr = lr ** (warmup_exp - (warmup_exp - 1) * progress)
-            return min_lr + lr * (1 - min_lr)
-
-        if annealing == "linear":
-            thresh_down = thresh_up * 2
-
-            if current_step < thresh_down:
-                progress = float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up))
-                return min_lr + progress * (1 - min_lr)
-
-            progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down))
-            return progress * min_lr
-
-        progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))
-
-        if annealing == "half_cos":
-            lr = 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
-            lr = lr ** (annealing_exp - (annealing_exp - 1) * progress)
-            return lr
-
-        lr = 0.5 * (1.0 + math.cos(math.pi * progress))
-        lr = lr ** (annealing_exp - (annealing_exp - 1) * progress)
-        return lr
+    if warmup == "linear":
+        warmup_func = warmup_linear
+    else:
+        warmup_func = partial(warmup_cos, warmup_exp)
+
+    if annealing == "linear":
+        anneal_func = anneal_linear
+    elif annealing == "half_cos":
+        anneal_func = partial(anneal_half_cos, annealing_exp)
+    else:
+        anneal_func = partial(anneal_cos, annealing_exp)
+
+    thresh_up = int(num_training_steps * min(mid_point, 0.5))
+
+    if annealing == "linear":
+        thresh_down = thresh_up * 2
+        phases = [
+            OneCyclePhase(0, thresh_up, min_lr, 1, warmup_func),
+            OneCyclePhase(thresh_up, thresh_down, min_lr, 1, anneal_func),
+            OneCyclePhase(thresh_down, num_training_steps, 0, min_lr, anneal_func),
+        ]
+    else:
+        phases = [
+            OneCyclePhase(0, thresh_up, 0, 1, warmup_func),
+            OneCyclePhase(thresh_up, num_training_steps, 0, 1, anneal_func),
+        ]
+
+    def lr_lambda(current_step: int):
+        phase = [p for p in phases if current_step >= p.step_min][-1]
+        return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min)
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
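To close the loop, a minimal usage sketch of the refactored scheduler, assuming the repo root is importable so `training.optimization` resolves; keyword arguments are used because the parameters between num_training_steps and mid_point are not visible in this hunk, and the concrete values are illustrative only.

import torch
from training.optimization import get_one_cycle_schedule

# Dummy parameter and optimizer purely for demonstration.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-3)

scheduler = get_one_cycle_schedule(
    optimizer,
    num_training_steps=100,
    warmup="cos",        # anything other than "linear" selects warmup_cos
    warmup_exp=1,
    annealing="cos",     # "linear", "half_cos", or "cos"
    annealing_exp=1,
    min_lr=0.04,
    mid_point=0.3,
)

for step in range(100):
    optimizer.step()
    scheduler.step()

# With annealing != "linear" the multiplier ramps 0 -> 1 over the first
# 30 steps (mid_point * num_training_steps) and then decays back to 0,
# so the actual LR peaks at the optimizer's base value of 1e-3.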