summaryrefslogtreecommitdiffstats
path: root/training/optimization.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/optimization.py')
-rw-r--r--training/optimization.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/training/optimization.py b/training/optimization.py
index 59ca950..53d0a6d 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,6 +6,7 @@ import torch
6from torch.optim.lr_scheduler import LambdaLR 6from torch.optim.lr_scheduler import LambdaLR
7 7
8from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup 8from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
9import transformers
9 10
10 11
11class OneCyclePhase(NamedTuple): 12class OneCyclePhase(NamedTuple):
@@ -148,6 +149,8 @@ def get_scheduler(
148 num_training_steps=num_training_steps, 149 num_training_steps=num_training_steps,
149 num_cycles=cycles, 150 num_cycles=cycles,
150 ) 151 )
152 elif id == "adafactor":
153 lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr)
151 else: 154 else:
152 lr_scheduler = get_scheduler_( 155 lr_scheduler = get_scheduler_(
153 id, 156 id,