| author    | Volpeon <git@volpeon.ink> | 2023-01-13 22:25:30 +0100 |
| committer | Volpeon <git@volpeon.ink> | 2023-01-13 22:25:30 +0100 |
| commit    | 3e7fbb7dce321435bbbb81361debfbc499bf9231 (patch) | |
| tree      | e7d5cefd2eda9755ab58861862f1978c13386f0d /training/optimization.py | |
| parent    | More modularization (diff) | |
Reverted modularization mostly
Diffstat (limited to 'training/optimization.py')
| -rw-r--r-- | training/optimization.py | 53 |
1 file changed, 53 insertions, 0 deletions
```diff
diff --git a/training/optimization.py b/training/optimization.py
index dd84f9c..5db7794 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -5,6 +5,8 @@ from functools import partial
 import torch
 from torch.optim.lr_scheduler import LambdaLR
 
+from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
+
 
 class OneCyclePhase(NamedTuple):
     step_min: int
@@ -83,3 +85,54 @@ def get_one_cycle_schedule(
         return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min)
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_scheduler(
+    id: str,
+    optimizer: torch.optim.Optimizer,
+    num_training_steps_per_epoch: int,
+    gradient_accumulation_steps: int,
+    min_lr: float = 0.04,
+    warmup_func: str = "cos",
+    annealing_func: str = "cos",
+    warmup_exp: int = 1,
+    annealing_exp: int = 1,
+    cycles: int = 1,
+    train_epochs: int = 100,
+    warmup_epochs: int = 10,
+):
+    num_training_steps_per_epoch = math.ceil(
+        num_training_steps_per_epoch / gradient_accumulation_steps
+    ) * gradient_accumulation_steps
+    num_training_steps = train_epochs * num_training_steps_per_epoch
+    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
+
+    if id == "one_cycle":
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=num_training_steps,
+            warmup=warmup_func,
+            annealing=annealing_func,
+            warmup_exp=warmup_exp,
+            annealing_exp=annealing_exp,
+            min_lr=min_lr,
+        )
+    elif id == "cosine_with_restarts":
+        if cycles is None:
+            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            num_cycles=cycles,
+        )
+    else:
+        lr_scheduler = get_scheduler_(
+            id,
+            optimizer=optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+        )
+
+    return lr_scheduler
```
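For context, a minimal usage sketch of the `get_scheduler` wrapper introduced by this commit (not part of the commit itself). It assumes the file is importable as `training.optimization` and that `torch` and `diffusers` are installed; the dummy model and the batch counts are made up for illustration.

```python
import torch

from training.optimization import get_scheduler  # assumes the repo root is on PYTHONPATH

# Dummy model and optimizer, purely for illustration.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# With 1000 batches per epoch and 4 gradient accumulation steps, the wrapper
# first rounds the per-epoch step count up to a multiple of 4
# (math.ceil(1000 / 4) * 4 == 1000), then derives
#   num_training_steps = train_epochs  * steps_per_epoch
#   num_warmup_steps   = warmup_epochs * steps_per_epoch
# and dispatches on `id`: "one_cycle" uses the local get_one_cycle_schedule,
# "cosine_with_restarts" uses diffusers' hard-restart cosine schedule, and any
# other id falls through to diffusers.optimization.get_scheduler.
lr_scheduler = get_scheduler(
    "one_cycle",
    optimizer=optimizer,
    num_training_steps_per_epoch=1000,
    gradient_accumulation_steps=4,
    train_epochs=100,
    warmup_epochs=10,
)

for _ in range(5):  # a few steps of the usual optimizer/scheduler loop
    optimizer.step()
    lr_scheduler.step()
```

Note that although `cycles` defaults to 1, passing `cycles=None` with `id="cosine_with_restarts"` makes the wrapper derive the number of restarts from the schedule length, roughly the square root of the number of post-warmup epochs.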
