diff options
author | Volpeon <git@volpeon.ink> | 2023-04-01 12:35:43 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-01 12:35:43 +0200 |
commit | 01eee0cb24f52ca78761b78917959e1c247eae94 (patch) | |
tree | 914c0d3f5b888a4c344b30a861639c8e3d5259dd /training/optimization.py | |
parent | Update (diff) | |
download | textual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.tar.gz textual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.tar.bz2 textual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.zip |
Add support for Adafactor, add TI initializer noise
Diffstat (limited to 'training/optimization.py')
-rw-r--r-- | training/optimization.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/training/optimization.py b/training/optimization.py index 59ca950..53d0a6d 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
@@ -6,6 +6,7 @@ import torch | |||
6 | from torch.optim.lr_scheduler import LambdaLR | 6 | from torch.optim.lr_scheduler import LambdaLR |
7 | 7 | ||
8 | from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup | 8 | from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup |
9 | import transformers | ||
9 | 10 | ||
10 | 11 | ||
11 | class OneCyclePhase(NamedTuple): | 12 | class OneCyclePhase(NamedTuple): |
@@ -148,6 +149,8 @@ def get_scheduler( | |||
148 | num_training_steps=num_training_steps, | 149 | num_training_steps=num_training_steps, |
149 | num_cycles=cycles, | 150 | num_cycles=cycles, |
150 | ) | 151 | ) |
152 | elif id == "adafactor": | ||
153 | lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr) | ||
151 | else: | 154 | else: |
152 | lr_scheduler = get_scheduler_( | 155 | lr_scheduler = get_scheduler_( |
153 | id, | 156 | id, |