summaryrefslogtreecommitdiffstats
path: root/training/optimization.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/optimization.py')
-rw-r--r--training/optimization.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/training/optimization.py b/training/optimization.py
index 53d0a6d..d22a900 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,7 +6,7 @@ import torch
6from torch.optim.lr_scheduler import LambdaLR 6from torch.optim.lr_scheduler import LambdaLR
7 7
8from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup 8from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
9import transformers 9from transformers.optimization import get_adafactor_schedule
10 10
11 11
12class OneCyclePhase(NamedTuple): 12class OneCyclePhase(NamedTuple):
@@ -150,7 +150,10 @@ def get_scheduler(
150 num_cycles=cycles, 150 num_cycles=cycles,
151 ) 151 )
152 elif id == "adafactor": 152 elif id == "adafactor":
153 lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr) 153 lr_scheduler = get_adafactor_schedule(
154 optimizer,
155 initial_lr=min_lr
156 )
154 else: 157 else:
155 lr_scheduler = get_scheduler_( 158 lr_scheduler = get_scheduler_(
156 id, 159 id,