summaryrefslogtreecommitdiffstats
path: root/training/optimization.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-01 12:35:43 +0200
committerVolpeon <git@volpeon.ink>2023-04-01 12:35:43 +0200
commit01eee0cb24f52ca78761b78917959e1c247eae94 (patch)
tree914c0d3f5b888a4c344b30a861639c8e3d5259dd /training/optimization.py
parentUpdate (diff)
downloadtextual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.tar.gz
textual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.tar.bz2
textual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.zip
Add support for Adafactor, add TI initializer noise
Diffstat (limited to 'training/optimization.py')
-rw-r--r--training/optimization.py3
1 file changed, 3 insertions, 0 deletions
diff --git a/training/optimization.py b/training/optimization.py
index 59ca950..53d0a6d 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,6 +6,7 @@ import torch
6from torch.optim.lr_scheduler import LambdaLR 6from torch.optim.lr_scheduler import LambdaLR
7 7
8from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup 8from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
9import transformers
9 10
10 11
11class OneCyclePhase(NamedTuple): 12class OneCyclePhase(NamedTuple):
@@ -148,6 +149,8 @@ def get_scheduler(
148 num_training_steps=num_training_steps, 149 num_training_steps=num_training_steps,
149 num_cycles=cycles, 150 num_cycles=cycles,
150 ) 151 )
152 elif id == "adafactor":
153 lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr)
151 else: 154 else:
152 lr_scheduler = get_scheduler_( 155 lr_scheduler = get_scheduler_(
153 id, 156 id,