From a551a9ac2edd1dc59828749a5e5d73a65b3c9ce7 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 1 Apr 2023 15:54:40 +0200
Subject: Update

---
 training/optimization.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

(limited to 'training/optimization.py')

diff --git a/training/optimization.py b/training/optimization.py
index 53d0a6d..d22a900 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,7 +6,7 @@ import torch
 from torch.optim.lr_scheduler import LambdaLR
 from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
-import transformers
+from transformers.optimization import get_adafactor_schedule
 
 
 class OneCyclePhase(NamedTuple):
@@ -150,7 +150,10 @@ def get_scheduler(
             num_cycles=cycles,
         )
     elif id == "adafactor":
-        lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, min_lr)
+        lr_scheduler = get_adafactor_schedule(
+            optimizer,
+            initial_lr=min_lr
+        )
     else:
         lr_scheduler = get_scheduler_(
            id,
-- 
cgit v1.2.3-70-g09d2