diff options
author | Volpeon <git@volpeon.ink> | 2023-04-01 15:54:40 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-01 15:54:40 +0200 |
commit | a551a9ac2edd1dc59828749a5e5d73a65b3c9ce7 (patch) | |
tree | 7ccca7f3a70b2b34706ddb849e37924aa6ee88e9 /training/optimization.py | |
parent | Add support for Adafactor, add TI initializer noise (diff) | |
download | textual-inversion-diff-a551a9ac2edd1dc59828749a5e5d73a65b3c9ce7.tar.gz textual-inversion-diff-a551a9ac2edd1dc59828749a5e5d73a65b3c9ce7.tar.bz2 textual-inversion-diff-a551a9ac2edd1dc59828749a5e5d73a65b3c9ce7.zip |
Update
Diffstat (limited to 'training/optimization.py')
-rw-r--r-- | training/optimization.py | 7 |
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/training/optimization.py b/training/optimization.py index 53d0a6d..d22a900 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
@@ -6,7 +6,7 @@ import torch | |||
6 | from torch.optim.lr_scheduler import LambdaLR | 6 | from torch.optim.lr_scheduler import LambdaLR |
7 | 7 | ||
8 | from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup | 8 | from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup |
9 | import transformers | 9 | from transformers.optimization import get_adafactor_schedule |
10 | 10 | ||
11 | 11 | ||
12 | class OneCyclePhase(NamedTuple): | 12 | class OneCyclePhase(NamedTuple): |
@@ -150,7 +150,10 @@ def get_scheduler( | |||
150 | num_cycles=cycles, | 150 | num_cycles=cycles, |
151 | ) | 151 | ) |
152 | elif id == "adafactor": | 152 | elif id == "adafactor": |
153 | lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr) | 153 | lr_scheduler = get_adafactor_schedule( |
154 | optimizer, | ||
155 | initial_lr=min_lr | ||
156 | ) | ||
154 | else: | 157 | else: |
155 | lr_scheduler = get_scheduler_( | 158 | lr_scheduler = get_scheduler_( |
156 | id, | 159 | id, |