| author | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
| commit | 8364ce697ddf6117fdd4f7222832d546d63880de (patch) | |
| tree | 152c99815bbd8b2659d0dabe63c98f63151c97c2 /training/optimization.py | |
| parent | Fix LoRA training with DAdan (diff) | |
Update
Diffstat (limited to 'training/optimization.py')
| -rw-r--r-- | training/optimization.py | 38 |
1 file changed, 28 insertions, 10 deletions
diff --git a/training/optimization.py b/training/optimization.py
index d22a900..55531bf 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -5,7 +5,10 @@ from functools import partial
 import torch
 from torch.optim.lr_scheduler import LambdaLR
 
-from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
+from diffusers.optimization import (
+    get_scheduler as get_scheduler_,
+    get_cosine_with_hard_restarts_schedule_with_warmup,
+)
 from transformers.optimization import get_adafactor_schedule
 
 
@@ -52,7 +55,7 @@ def get_one_cycle_schedule(
     annealing_exp: int = 1,
     min_lr: float = 0.04,
     mid_point: float = 0.3,
-    last_epoch: int = -1
+    last_epoch: int = -1,
 ):
     if warmup == "linear":
         warmup_func = warmup_linear
@@ -83,12 +86,16 @@ def get_one_cycle_schedule(
 
     def lr_lambda(current_step: int):
         phase = [p for p in phases if current_step >= p.step_min][-1]
-        return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min)
+        return phase.min + phase.func(
+            (current_step - phase.step_min) / (phase.step_max - phase.step_min)
+        ) * (phase.max - phase.min)
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
 
 
-def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1):
+def get_exponential_growing_schedule(
+    optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1
+):
     def lr_lambda(base_lr: float, current_step: int):
         return (end_lr / base_lr) ** (current_step / num_training_steps)
 
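Most of the hunk above is formatting, but the reflowed `lr_lambda` is the core of the one-cycle schedule: it picks the phase containing the current step, rescales the step to a 0–1 position within that phase, and maps it through the phase's curve between the phase's `min` and `max` multipliers. A minimal, self-contained sketch of that interpolation, using a hypothetical two-phase list (the real `phases` are built earlier in `get_one_cycle_schedule` and are not part of this diff):

```python
from typing import Callable, NamedTuple


# Hypothetical stand-in for the phase records that get_one_cycle_schedule
# builds before defining lr_lambda; the field names mirror the attributes
# referenced in the diff (step_min, step_max, min, max, func).
class Phase(NamedTuple):
    step_min: int
    step_max: int
    min: float  # LR multiplier at the start of the phase
    max: float  # LR multiplier at the end of the phase
    func: Callable[[float], float]  # warmup/annealing curve on [0, 1]


phases = [
    Phase(0, 30, 0.0, 1.0, lambda t: t),     # linear warmup up to the peak
    Phase(30, 100, 1.0, 0.04, lambda t: t),  # linear anneal down to min_lr
]


def lr_lambda(current_step: int) -> float:
    # Same selection and interpolation as the reformatted code above.
    phase = [p for p in phases if current_step >= p.step_min][-1]
    return phase.min + phase.func(
        (current_step - phase.step_min) / (phase.step_max - phase.step_min)
    ) * (phase.max - phase.min)


print(round(lr_lambda(15), 4))  # 0.5  -> halfway through warmup
print(round(lr_lambda(65), 4))  # 0.52 -> halfway through annealing
```

`LambdaLR` multiplies each parameter group's base LR by this factor, so these values are multipliers rather than absolute learning rates.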
@@ -132,7 +139,14 @@ def get_scheduler(
         )
     elif id == "exponential_growth":
         if cycles is None:
-            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+            cycles = math.ceil(
+                math.sqrt(
+                    (
+                        (num_training_steps - num_warmup_steps)
+                        / num_training_steps_per_epoch
+                    )
+                )
+            )
 
         lr_scheduler = get_exponential_growing_schedule(
             optimizer=optimizer,
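For context, `get_exponential_growing_schedule` (reformatted in an earlier hunk) returns the factor `(end_lr / base_lr) ** (current_step / num_training_steps)`, with `base_lr` presumably bound per parameter group via `functools.partial`; under `LambdaLR` this turns the base LR into a geometric ramp that reaches `end_lr` on the final step. A quick numerical sketch with made-up values:

```python
# Made-up values purely to illustrate the growth rule; the real base LR comes
# from the optimizer and end_lr/num_training_steps from the training config.
base_lr = 1e-6
end_lr = 1e-2
num_training_steps = 100


def factor(current_step: int) -> float:
    # The multiplier returned by the schedule's lr_lambda.
    return (end_lr / base_lr) ** (current_step / num_training_steps)


for step in (0, 25, 50, 75, 100):
    # Effective LR = base_lr * factor(step): a geometric ramp from 1e-6 to 1e-2.
    print(step, base_lr * factor(step))  # ~1e-06, ~1e-05, ~1e-04, ~1e-03, ~1e-02
```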
@@ -141,7 +155,14 @@ def get_scheduler(
         )
     elif id == "cosine_with_restarts":
         if cycles is None:
-            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+            cycles = math.ceil(
+                math.sqrt(
+                    (
+                        (num_training_steps - num_warmup_steps)
+                        / num_training_steps_per_epoch
+                    )
+                )
+            )
 
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
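Both the `exponential_growth` and `cosine_with_restarts` branches share the same fallback when `cycles` is not supplied: the ceiling of the square root of the number of post-warmup epochs. A worked example with hypothetical step counts:

```python
import math

# Hypothetical run: 3000 total steps, 300 warmup steps, 90 steps per epoch.
num_training_steps = 3000
num_warmup_steps = 300
num_training_steps_per_epoch = 90

cycles = math.ceil(
    math.sqrt((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)
)
print(cycles)  # (3000 - 300) / 90 = 30 post-warmup epochs -> sqrt(30) ≈ 5.48 -> 6
```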
@@ -150,10 +171,7 @@ def get_scheduler(
             num_cycles=cycles,
         )
     elif id == "adafactor":
-        lr_scheduler = get_adafactor_schedule(
-            optimizer,
-            initial_lr=min_lr
-        )
+        lr_scheduler = get_adafactor_schedule(optimizer, initial_lr=min_lr)
     else:
         lr_scheduler = get_scheduler_(
             id,
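The `adafactor` branch now fits on one line. `get_adafactor_schedule` from `transformers` is a proxy schedule that reports the learning rate Adafactor derives internally, and `initial_lr` (the repo passes its `min_lr`) is only the placeholder reported before the first optimizer step. A hedged usage sketch outside this repo's `get_scheduler` wrapper, with a hypothetical model and made-up hyperparameters:

```python
import torch
from transformers.optimization import Adafactor, get_adafactor_schedule

model = torch.nn.Linear(16, 4)  # stand-in model, for illustration only

# Adafactor managing its own relative-step learning rate (lr=None).
optimizer = Adafactor(
    model.parameters(),
    lr=None,
    scale_parameter=True,
    relative_step=True,
    warmup_init=True,
)

# Proxy schedule so a generic training loop can still call get_last_lr();
# initial_lr is only what is reported before any step has been taken.
lr_scheduler = get_adafactor_schedule(optimizer, initial_lr=1e-4)

loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()
optimizer.step()
lr_scheduler.step()
print(lr_scheduler.get_last_lr())  # Adafactor's internally computed LR
```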
