path: root/training/optimization.py
author     Volpeon <git@volpeon.ink>  2023-01-13 22:25:30 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-13 22:25:30 +0100
commit     3e7fbb7dce321435bbbb81361debfbc499bf9231 (patch)
tree       e7d5cefd2eda9755ab58861862f1978c13386f0d /training/optimization.py
parent     More modularization (diff)
download   textual-inversion-diff-3e7fbb7dce321435bbbb81361debfbc499bf9231.tar.gz
           textual-inversion-diff-3e7fbb7dce321435bbbb81361debfbc499bf9231.tar.bz2
           textual-inversion-diff-3e7fbb7dce321435bbbb81361debfbc499bf9231.zip
Reverted modularization mostly
Diffstat (limited to 'training/optimization.py')
-rw-r--r--  training/optimization.py  53
1 file changed, 53 insertions(+), 0 deletions(-)
diff --git a/training/optimization.py b/training/optimization.py
index dd84f9c..5db7794 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -5,6 +5,8 @@ from functools import partial
 import torch
 from torch.optim.lr_scheduler import LambdaLR
 
+from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
+
 
 class OneCyclePhase(NamedTuple):
     step_min: int
@@ -83,3 +85,54 @@ def get_one_cycle_schedule(
         return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min)
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_scheduler(
+    id: str,
+    optimizer: torch.optim.Optimizer,
+    num_training_steps_per_epoch: int,
+    gradient_accumulation_steps: int,
+    min_lr: float = 0.04,
+    warmup_func: str = "cos",
+    annealing_func: str = "cos",
+    warmup_exp: int = 1,
+    annealing_exp: int = 1,
+    cycles: int = 1,
+    train_epochs: int = 100,
+    warmup_epochs: int = 10,
+):
+    num_training_steps_per_epoch = math.ceil(
+        num_training_steps_per_epoch / gradient_accumulation_steps
+    ) * gradient_accumulation_steps
+    num_training_steps = train_epochs * num_training_steps_per_epoch
+    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
+
+    if id == "one_cycle":
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=num_training_steps,
+            warmup=warmup_func,
+            annealing=annealing_func,
+            warmup_exp=warmup_exp,
+            annealing_exp=annealing_exp,
+            min_lr=min_lr,
+        )
+    elif id == "cosine_with_restarts":
+        if cycles is None:
+            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            num_cycles=cycles,
+        )
+    else:
+        lr_scheduler = get_scheduler_(
+            id,
+            optimizer=optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+        )
+
+    return lr_scheduler
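
A minimal usage sketch of the get_scheduler() wrapper added above (not part of the commit). It assumes the repository root is on PYTHONPATH so training.optimization is importable; the toy model and step counts are illustrative placeholders.

# Hypothetical usage sketch, not from the repository.
import torch

from training.optimization import get_scheduler

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

lr_scheduler = get_scheduler(
    "one_cycle",                       # or "cosine_with_restarts", or any id the diffusers get_scheduler accepts
    optimizer=optimizer,
    num_training_steps_per_epoch=100,  # rounded up internally to a multiple of gradient_accumulation_steps
    gradient_accumulation_steps=4,
    train_epochs=10,
    warmup_epochs=1,
)

for _ in range(10 * 100):  # train_epochs * steps per epoch
    optimizer.step()
    lr_scheduler.step()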