author     Volpeon <git@volpeon.ink>  2023-06-21 13:28:49 +0200
committer  Volpeon <git@volpeon.ink>  2023-06-21 13:28:49 +0200
commit     8364ce697ddf6117fdd4f7222832d546d63880de (patch)
tree       152c99815bbd8b2659d0dabe63c98f63151c97c2 /training/optimization.py
parent     Fix LoRA training with DAdan (diff)
Update
Diffstat (limited to 'training/optimization.py')
-rw-r--r--  training/optimization.py  38
1 file changed, 28 insertions(+), 10 deletions(-)
diff --git a/training/optimization.py b/training/optimization.py
index d22a900..55531bf 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -5,7 +5,10 @@ from functools import partial
 import torch
 from torch.optim.lr_scheduler import LambdaLR
 
-from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
+from diffusers.optimization import (
+    get_scheduler as get_scheduler_,
+    get_cosine_with_hard_restarts_schedule_with_warmup,
+)
 from transformers.optimization import get_adafactor_schedule
 
 
@@ -52,7 +55,7 @@ def get_one_cycle_schedule(
     annealing_exp: int = 1,
     min_lr: float = 0.04,
     mid_point: float = 0.3,
-    last_epoch: int = -1
+    last_epoch: int = -1,
 ):
     if warmup == "linear":
         warmup_func = warmup_linear
@@ -83,12 +86,16 @@ def get_one_cycle_schedule(
 
     def lr_lambda(current_step: int):
         phase = [p for p in phases if current_step >= p.step_min][-1]
-        return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min)
+        return phase.min + phase.func(
+            (current_step - phase.step_min) / (phase.step_max - phase.step_min)
+        ) * (phase.max - phase.min)
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
 
 
-def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1):
+def get_exponential_growing_schedule(
+    optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1
+):
     def lr_lambda(base_lr: float, current_step: int):
         return (end_lr / base_lr) ** (current_step / num_training_steps)
 
@@ -132,7 +139,14 @@ def get_scheduler(
         )
     elif id == "exponential_growth":
         if cycles is None:
-            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+            cycles = math.ceil(
+                math.sqrt(
+                    (
+                        (num_training_steps - num_warmup_steps)
+                        / num_training_steps_per_epoch
+                    )
+                )
+            )
 
         lr_scheduler = get_exponential_growing_schedule(
             optimizer=optimizer,
@@ -141,7 +155,14 @@ def get_scheduler(
         )
     elif id == "cosine_with_restarts":
         if cycles is None:
-            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+            cycles = math.ceil(
+                math.sqrt(
+                    (
+                        (num_training_steps - num_warmup_steps)
+                        / num_training_steps_per_epoch
+                    )
+                )
+            )
 
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
@@ -150,10 +171,7 @@ def get_scheduler(
             num_cycles=cycles,
         )
     elif id == "adafactor":
-        lr_scheduler = get_adafactor_schedule(
-            optimizer,
-            initial_lr=min_lr
-        )
+        lr_scheduler = get_adafactor_schedule(optimizer, initial_lr=min_lr)
     else:
         lr_scheduler = get_scheduler_(
             id,
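
For context on what the reformatted branches compute, below is a minimal, self-contained sketch of the exponential-growth schedule touched in the hunks above. It is an illustration only: the optimizer, base_lr, end_lr, and step counts are assumed values, not taken from this repository's training scripts, and the repository's lr_lambda takes base_lr as an explicit parameter (presumably bound per parameter group, since the file imports functools.partial), while the sketch folds it in as a closure variable.

# Sketch of the exponential-growth LR schedule shown in the diff above.
# All concrete numbers (base_lr, end_lr, num_training_steps) are assumptions.
import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(4, 4)
base_lr = 1e-6
end_lr = 1e-2
num_training_steps = 100
optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)


def lr_lambda(current_step: int) -> float:
    # The multiplier grows geometrically: 1.0 at step 0 (LR = base_lr),
    # end_lr / base_lr at the final step (LR = end_lr).
    return (end_lr / base_lr) ** (current_step / num_training_steps)


scheduler = LambdaLR(optimizer, lr_lambda)

for step in range(num_training_steps + 1):
    if step in (0, 50, 100):
        # Prints roughly 1e-06, 1e-04, 1e-02: with these assumed values the
        # LR climbs one order of magnitude every 25 steps.
        print(step, optimizer.param_groups[0]["lr"])
    optimizer.step()
    scheduler.step()

The cycles default shared by both reworked "if cycles is None:" branches is unchanged by the reformat; e.g. with assumed values num_training_steps=3000, num_warmup_steps=300, and num_training_steps_per_epoch=100, it gives math.ceil(math.sqrt(2700 / 100)) = math.ceil(5.196...) = 6 cycles.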