diff options
author | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
commit | 8364ce697ddf6117fdd4f7222832d546d63880de (patch) | |
tree | 152c99815bbd8b2659d0dabe63c98f63151c97c2 /training/optimization.py | |
parent | Fix LoRA training with DAdan (diff) | |
download | textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.gz textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.bz2 textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.zip |
Update
Diffstat (limited to 'training/optimization.py')
-rw-r--r-- | training/optimization.py | 38 |
1 file changed, 28 insertions, 10 deletions
diff --git a/training/optimization.py b/training/optimization.py index d22a900..55531bf 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
@@ -5,7 +5,10 @@ from functools import partial | |||
5 | import torch | 5 | import torch |
6 | from torch.optim.lr_scheduler import LambdaLR | 6 | from torch.optim.lr_scheduler import LambdaLR |
7 | 7 | ||
8 | from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup | 8 | from diffusers.optimization import ( |
9 | get_scheduler as get_scheduler_, | ||
10 | get_cosine_with_hard_restarts_schedule_with_warmup, | ||
11 | ) | ||
9 | from transformers.optimization import get_adafactor_schedule | 12 | from transformers.optimization import get_adafactor_schedule |
10 | 13 | ||
11 | 14 | ||
@@ -52,7 +55,7 @@ def get_one_cycle_schedule( | |||
52 | annealing_exp: int = 1, | 55 | annealing_exp: int = 1, |
53 | min_lr: float = 0.04, | 56 | min_lr: float = 0.04, |
54 | mid_point: float = 0.3, | 57 | mid_point: float = 0.3, |
55 | last_epoch: int = -1 | 58 | last_epoch: int = -1, |
56 | ): | 59 | ): |
57 | if warmup == "linear": | 60 | if warmup == "linear": |
58 | warmup_func = warmup_linear | 61 | warmup_func = warmup_linear |
@@ -83,12 +86,16 @@ def get_one_cycle_schedule( | |||
83 | 86 | ||
84 | def lr_lambda(current_step: int): | 87 | def lr_lambda(current_step: int): |
85 | phase = [p for p in phases if current_step >= p.step_min][-1] | 88 | phase = [p for p in phases if current_step >= p.step_min][-1] |
86 | return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min) | 89 | return phase.min + phase.func( |
90 | (current_step - phase.step_min) / (phase.step_max - phase.step_min) | ||
91 | ) * (phase.max - phase.min) | ||
87 | 92 | ||
88 | return LambdaLR(optimizer, lr_lambda, last_epoch) | 93 | return LambdaLR(optimizer, lr_lambda, last_epoch) |
89 | 94 | ||
90 | 95 | ||
91 | def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1): | 96 | def get_exponential_growing_schedule( |
97 | optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1 | ||
98 | ): | ||
92 | def lr_lambda(base_lr: float, current_step: int): | 99 | def lr_lambda(base_lr: float, current_step: int): |
93 | return (end_lr / base_lr) ** (current_step / num_training_steps) | 100 | return (end_lr / base_lr) ** (current_step / num_training_steps) |
94 | 101 | ||
@@ -132,7 +139,14 @@ def get_scheduler( | |||
132 | ) | 139 | ) |
133 | elif id == "exponential_growth": | 140 | elif id == "exponential_growth": |
134 | if cycles is None: | 141 | if cycles is None: |
135 | cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) | 142 | cycles = math.ceil( |
143 | math.sqrt( | ||
144 | ( | ||
145 | (num_training_steps - num_warmup_steps) | ||
146 | / num_training_steps_per_epoch | ||
147 | ) | ||
148 | ) | ||
149 | ) | ||
136 | 150 | ||
137 | lr_scheduler = get_exponential_growing_schedule( | 151 | lr_scheduler = get_exponential_growing_schedule( |
138 | optimizer=optimizer, | 152 | optimizer=optimizer, |
@@ -141,7 +155,14 @@ def get_scheduler( | |||
141 | ) | 155 | ) |
142 | elif id == "cosine_with_restarts": | 156 | elif id == "cosine_with_restarts": |
143 | if cycles is None: | 157 | if cycles is None: |
144 | cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) | 158 | cycles = math.ceil( |
159 | math.sqrt( | ||
160 | ( | ||
161 | (num_training_steps - num_warmup_steps) | ||
162 | / num_training_steps_per_epoch | ||
163 | ) | ||
164 | ) | ||
165 | ) | ||
145 | 166 | ||
146 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 167 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
147 | optimizer=optimizer, | 168 | optimizer=optimizer, |
@@ -150,10 +171,7 @@ def get_scheduler( | |||
150 | num_cycles=cycles, | 171 | num_cycles=cycles, |
151 | ) | 172 | ) |
152 | elif id == "adafactor": | 173 | elif id == "adafactor": |
153 | lr_scheduler = get_adafactor_schedule( | 174 | lr_scheduler = get_adafactor_schedule(optimizer, initial_lr=min_lr) |
154 | optimizer, | ||
155 | initial_lr=min_lr | ||
156 | ) | ||
157 | else: | 175 | else: |
158 | lr_scheduler = get_scheduler_( | 176 | lr_scheduler = get_scheduler_( |
159 | id, | 177 | id, |