Diffstat (limited to 'training/optimization.py')
-rw-r--r--  training/optimization.py  53
1 file changed, 53 insertions, 0 deletions
diff --git a/training/optimization.py b/training/optimization.py
index dd84f9c..5db7794 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -5,6 +5,8 @@ from functools import partial
 import torch
 from torch.optim.lr_scheduler import LambdaLR
 
+from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
+
 
 class OneCyclePhase(NamedTuple):
     step_min: int
@@ -83,3 +85,54 @@ def get_one_cycle_schedule(
         return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min)
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_scheduler(
+    id: str,
+    optimizer: torch.optim.Optimizer,
+    num_training_steps_per_epoch: int,
+    gradient_accumulation_steps: int,
+    min_lr: float = 0.04,
+    warmup_func: str = "cos",
+    annealing_func: str = "cos",
+    warmup_exp: int = 1,
+    annealing_exp: int = 1,
+    cycles: int = 1,
+    train_epochs: int = 100,
+    warmup_epochs: int = 10,
+):
+    num_training_steps_per_epoch = math.ceil(
+        num_training_steps_per_epoch / gradient_accumulation_steps
+    ) * gradient_accumulation_steps
+    num_training_steps = train_epochs * num_training_steps_per_epoch
+    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
+
+    if id == "one_cycle":
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=num_training_steps,
+            warmup=warmup_func,
+            annealing=annealing_func,
+            warmup_exp=warmup_exp,
+            annealing_exp=annealing_exp,
+            min_lr=min_lr,
+        )
+    elif id == "cosine_with_restarts":
+        if cycles is None:
+            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            num_cycles=cycles,
+        )
+    else:
+        lr_scheduler = get_scheduler_(
+            id,
+            optimizer=optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+        )
+
+    return lr_scheduler
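
For reference, a minimal usage sketch of the new helper (not part of the patch): the import path, model, optimizer, and step counts below are illustrative assumptions, not values taken from this repository.

# Usage sketch (assumption): placeholder model, optimizer, and step counts.
import torch

from training.optimization import get_scheduler

model = torch.nn.Linear(16, 16)  # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# num_training_steps_per_epoch is rounded up to a multiple of
# gradient_accumulation_steps before the total and warmup step counts
# are derived, e.g. 250 steps with 4 accumulation steps becomes 252.
lr_scheduler = get_scheduler(
    "one_cycle",  # or "cosine_with_restarts", or any id accepted by diffusers' get_scheduler
    optimizer=optimizer,
    num_training_steps_per_epoch=250,
    gradient_accumulation_steps=4,
    train_epochs=100,
    warmup_epochs=10,
)

for _ in range(10):  # training-loop stub
    optimizer.step()
    lr_scheduler.step()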