summaryrefslogtreecommitdiffstats
path: root/training/optimization.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-01 15:54:40 +0200
committerVolpeon <git@volpeon.ink>2023-04-01 15:54:40 +0200
commita551a9ac2edd1dc59828749a5e5d73a65b3c9ce7 (patch)
tree7ccca7f3a70b2b34706ddb849e37924aa6ee88e9 /training/optimization.py
parentAdd support for Adafactor, add TI initializer noise (diff)
downloadtextual-inversion-diff-a551a9ac2edd1dc59828749a5e5d73a65b3c9ce7.tar.gz
textual-inversion-diff-a551a9ac2edd1dc59828749a5e5d73a65b3c9ce7.tar.bz2
textual-inversion-diff-a551a9ac2edd1dc59828749a5e5d73a65b3c9ce7.zip
Update
Diffstat (limited to 'training/optimization.py')
-rw-r--r--training/optimization.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/training/optimization.py b/training/optimization.py
index 53d0a6d..d22a900 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,7 +6,7 @@ import torch
6from torch.optim.lr_scheduler import LambdaLR 6from torch.optim.lr_scheduler import LambdaLR
7 7
8from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup 8from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
9import transformers 9from transformers.optimization import get_adafactor_schedule
10 10
11 11
12class OneCyclePhase(NamedTuple): 12class OneCyclePhase(NamedTuple):
@@ -150,7 +150,10 @@ def get_scheduler(
150 num_cycles=cycles, 150 num_cycles=cycles,
151 ) 151 )
152 elif id == "adafactor": 152 elif id == "adafactor":
153 lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr) 153 lr_scheduler = get_adafactor_schedule(
154 optimizer,
155 initial_lr=min_lr
156 )
154 else: 157 else:
155 lr_scheduler = get_scheduler_( 158 lr_scheduler = get_scheduler_(
156 id, 159 id,