Diffstat (limited to 'training/optimization.py')
 training/optimization.py | 3 +++
 1 file changed, 3 insertions(+), 0 deletions(-)
diff --git a/training/optimization.py b/training/optimization.py
index 59ca950..53d0a6d 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,6 +6,7 @@ import torch
 from torch.optim.lr_scheduler import LambdaLR
 
 from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
+import transformers
 
 
 class OneCyclePhase(NamedTuple):
@@ -148,6 +149,8 @@ def get_scheduler(
             num_training_steps=num_training_steps,
             num_cycles=cycles,
         )
+    elif id == "adafactor":
+        lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr)
     else:
         lr_scheduler = get_scheduler_(
             id,
