Diffstat (limited to 'training/optimization.py')
 training/optimization.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/training/optimization.py b/training/optimization.py
index 53d0a6d..d22a900 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,7 +6,7 @@ import torch
 from torch.optim.lr_scheduler import LambdaLR
 
 from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
-import transformers
+from transformers.optimization import get_adafactor_schedule
 
 
 class OneCyclePhase(NamedTuple):
@@ -150,7 +150,10 @@ def get_scheduler(
             num_cycles=cycles,
         )
     elif id == "adafactor":
-        lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr)
+        lr_scheduler = get_adafactor_schedule(
+            optimizer,
+            initial_lr=min_lr
+        )
     else:
         lr_scheduler = get_scheduler_(
             id,
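
Note: a minimal sketch of how the patched "adafactor" branch is exercised, not part of the commit itself. get_adafactor_schedule(optimizer, initial_lr=...) is the public transformers helper wrapping the AdafactorSchedule proxy that the old code constructed directly; the model and min_lr values below are placeholders.

    import torch
    from transformers.optimization import Adafactor, get_adafactor_schedule

    model = torch.nn.Linear(4, 4)  # placeholder standing in for the real model
    min_lr = 1e-4                  # placeholder for get_scheduler's min_lr argument

    # In relative-step mode Adafactor derives its own per-step learning rate,
    # so the optimizer itself is created with lr=None.
    optimizer = Adafactor(
        model.parameters(),
        scale_parameter=True,
        relative_step=True,
        warmup_init=True,
        lr=None,
    )

    # get_adafactor_schedule returns a proxy LambdaLR that reports the learning
    # rate Adafactor computes internally; initial_lr is only the value surfaced
    # before the first optimizer step.
    lr_scheduler = get_adafactor_schedule(optimizer, initial_lr=min_lr)

Using the public helper instead of instantiating AdafactorSchedule directly keeps the call site consistent with the rest of the file, which already pulls its scheduler factories from diffusers.optimization by name.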
