summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/optimization.py42
1 files changed, 42 insertions, 0 deletions
diff --git a/training/optimization.py b/training/optimization.py
new file mode 100644
index 0000000..012beed
--- /dev/null
+++ b/training/optimization.py
@@ -0,0 +1,42 @@
1import math
2from torch.optim.lr_scheduler import LambdaLR
3
4from diffusers.utils import logging
5
6logger = logging.get_logger(__name__)
7
8
9def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.05, mid_point=0.42, last_epoch=-1):
10 """
11 Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
12 a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
13 Args:
14 optimizer ([`~torch.optim.Optimizer`]):
15 The optimizer for which to schedule the learning rate.
16 num_training_steps (`int`):
17 The total number of training steps.
18 last_epoch (`int`, *optional*, defaults to -1):
19 The index of the last epoch when resuming training.
20 Return:
21 `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
22 """
23
24 def lr_lambda(current_step: int):
25 thresh_up = int(num_training_steps * min(mid_point, 0.5))
26
27 if current_step < thresh_up:
28 return min_lr + float(current_step) / float(max(1, thresh_up)) * (1 - min_lr)
29
30 if annealing == "linear":
31 thresh_down = thresh_up * 2
32
33 if current_step < thresh_down:
34 return min_lr + float(thresh_down - current_step) / float(max(1, thresh_down - thresh_up)) * (1 - min_lr)
35
36 return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - thresh_down))) * min_lr
37 else:
38 progress = float(current_step - thresh_up) / float(max(1, num_training_steps - thresh_up))
39
40 return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
41
42 return LambdaLR(optimizer, lr_lambda, last_epoch)