diff options
Diffstat (limited to 'training/optimization.py')
-rw-r--r-- | training/optimization.py | 7 |
1 file changed, 5 insertions, 2 deletions
diff --git a/training/optimization.py b/training/optimization.py index 53d0a6d..d22a900 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
@@ -6,7 +6,7 @@ import torch | |||
6 | from torch.optim.lr_scheduler import LambdaLR | 6 | from torch.optim.lr_scheduler import LambdaLR |
7 | 7 | ||
8 | from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup | 8 | from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup |
9 | import transformers | 9 | from transformers.optimization import get_adafactor_schedule |
10 | 10 | ||
11 | 11 | ||
12 | class OneCyclePhase(NamedTuple): | 12 | class OneCyclePhase(NamedTuple): |
@@ -150,7 +150,10 @@ def get_scheduler( | |||
150 | num_cycles=cycles, | 150 | num_cycles=cycles, |
151 | ) | 151 | ) |
152 | elif id == "adafactor": | 152 | elif id == "adafactor": |
153 | lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr) | 153 | lr_scheduler = get_adafactor_schedule( |
154 | optimizer, | ||
155 | initial_lr=min_lr | ||
156 | ) | ||
154 | else: | 157 | else: |
155 | lr_scheduler = get_scheduler_( | 158 | lr_scheduler = get_scheduler_( |
156 | id, | 159 | id, |