summaryrefslogtreecommitdiffstats
path: root/training/optimization.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-29 15:28:02 +0100
committerVolpeon <git@volpeon.ink>2022-12-29 15:28:02 +0100
commitf87d9fdf541b0282249ddde1dc0302317350f998 (patch)
treea27f4319d90098f026784711ffd1a415fa561def /training/optimization.py
parentTraining improvements (diff)
downloadtextual-inversion-diff-f87d9fdf541b0282249ddde1dc0302317350f998.tar.gz
textual-inversion-diff-f87d9fdf541b0282249ddde1dc0302317350f998.tar.bz2
textual-inversion-diff-f87d9fdf541b0282249ddde1dc0302317350f998.zip
Update
Diffstat (limited to 'training/optimization.py')
-rw-r--r--training/optimization.py14
1 files changed, 11 insertions, 3 deletions
diff --git a/training/optimization.py b/training/optimization.py
index dfee2b5..3340544 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -14,6 +14,8 @@ def get_one_cycle_schedule(
14 num_training_steps: int, 14 num_training_steps: int,
15 warmup: Literal["cos", "linear"] = "cos", 15 warmup: Literal["cos", "linear"] = "cos",
16 annealing: Literal["cos", "half_cos", "linear"] = "cos", 16 annealing: Literal["cos", "half_cos", "linear"] = "cos",
17 warmup_exp: int = 1,
18 annealing_exp: int = 2,
17 min_lr: int = 0.04, 19 min_lr: int = 0.04,
18 mid_point: int = 0.3, 20 mid_point: int = 0.3,
19 last_epoch: int = -1 21 last_epoch: int = -1
@@ -27,7 +29,9 @@ def get_one_cycle_schedule(
27 if warmup == "linear": 29 if warmup == "linear":
28 return min_lr + progress * (1 - min_lr) 30 return min_lr + progress * (1 - min_lr)
29 31
30 return min_lr + 0.5 * (1.0 + math.cos(math.pi * (1 + progress))) 32 lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
33 lr = lr ** warmup_exp
34 return min_lr + lr * (1 - min_lr)
31 35
32 if annealing == "linear": 36 if annealing == "linear":
33 thresh_down = thresh_up * 2 37 thresh_down = thresh_up * 2
@@ -42,8 +46,12 @@ def get_one_cycle_schedule(
42 progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up)) 46 progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))
43 47
44 if annealing == "half_cos": 48 if annealing == "half_cos":
45 return 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress)) 49 lr = 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
50 lr = lr ** annealing_exp
51 return lr
46 52
47 return 0.5 * (1.0 + math.cos(math.pi * progress)) 53 lr = 0.5 * (1.0 + math.cos(math.pi * progress))
54 lr = lr ** annealing_exp
55 return lr
48 56
49 return LambdaLR(optimizer, lr_lambda, last_epoch) 57 return LambdaLR(optimizer, lr_lambda, last_epoch)