summaryrefslogtreecommitdiffstats
path: root/training/optimization.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-29 09:00:19 +0100
committerVolpeon <git@volpeon.ink>2022-12-29 09:00:19 +0100
commit4d3d318a4168ef79847737cef2c0ad8a4dafd3e7 (patch)
tree967e2c1ee6e2c29b9b6ffaff3e8978f4a43a529d /training/optimization.py
parentUpdated 1-cycle scheduler (diff)
downloadtextual-inversion-diff-4d3d318a4168ef79847737cef2c0ad8a4dafd3e7.tar.gz
textual-inversion-diff-4d3d318a4168ef79847737cef2c0ad8a4dafd3e7.tar.bz2
textual-inversion-diff-4d3d318a4168ef79847737cef2c0ad8a4dafd3e7.zip
Training improvements
Diffstat (limited to 'training/optimization.py')
-rw-r--r--training/optimization.py43
1 files changed, 23 insertions, 20 deletions
diff --git a/training/optimization.py b/training/optimization.py
index a0c8673..dfee2b5 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -1,4 +1,7 @@
1import math 1import math
2from typing import Literal
3
4import torch
2from torch.optim.lr_scheduler import LambdaLR 5from torch.optim.lr_scheduler import LambdaLR
3 6
4from diffusers.utils import logging 7from diffusers.utils import logging
@@ -6,41 +9,41 @@ from diffusers.utils import logging
6logger = logging.get_logger(__name__) 9logger = logging.get_logger(__name__)
7 10
8 11
9def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.04, mid_point=0.3, last_epoch=-1): 12def get_one_cycle_schedule(
10 """ 13 optimizer: torch.optim.Optimizer,
11 Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after 14 num_training_steps: int,
12 a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. 15 warmup: Literal["cos", "linear"] = "cos",
13 Args: 16 annealing: Literal["cos", "half_cos", "linear"] = "cos",
14 optimizer ([`~torch.optim.Optimizer`]): 17 min_lr: int = 0.04,
15 The optimizer for which to schedule the learning rate. 18 mid_point: int = 0.3,
16 num_training_steps (`int`): 19 last_epoch: int = -1
17 The total number of training steps. 20):
18 last_epoch (`int`, *optional*, defaults to -1):
19 The index of the last epoch when resuming training.
20 Return:
21 `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
22 """
23
24 def lr_lambda(current_step: int): 21 def lr_lambda(current_step: int):
25 thresh_up = int(num_training_steps * min(mid_point, 0.5)) 22 thresh_up = int(num_training_steps * min(mid_point, 0.5))
26 23
27 if current_step < thresh_up: 24 if current_step < thresh_up:
28 return min_lr + float(current_step) / float(max(1, thresh_up)) * (1 - min_lr) 25 progress = float(current_step) / float(max(1, thresh_up))
26
27 if warmup == "linear":
28 return min_lr + progress * (1 - min_lr)
29
30 return min_lr + 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
29 31
30 if annealing == "linear": 32 if annealing == "linear":
31 thresh_down = thresh_up * 2 33 thresh_down = thresh_up * 2
32 34
33 if current_step < thresh_down: 35 if current_step < thresh_down:
34 return min_lr + float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up)) * (1 - min_lr) 36 progress = float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up))
37 return min_lr + progress * (1 - min_lr)
35 38
36 progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down)) 39 progress = float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down))
37 return max(0.0, progress) * min_lr 40 return progress * min_lr
38 41
39 progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up)) 42 progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))
40 43
41 if annealing == "half_cos": 44 if annealing == "half_cos":
42 return max(0.0, 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))) 45 return 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
43 46
44 return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) 47 return 0.5 * (1.0 + math.cos(math.pi * progress))
45 48
46 return LambdaLR(optimizer, lr_lambda, last_epoch) 49 return LambdaLR(optimizer, lr_lambda, last_epoch)