From 4f724ca8015771c55ab9f382ebec5fd8b3309eb2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Jun 2023 20:13:03 +0200 Subject: Added Prodigy optimizer --- training/functional.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 34a701b..cc079ef 100644 --- a/training/functional.py +++ b/training/functional.py @@ -525,6 +525,7 @@ def train_loop( on_checkpoint = callbacks.on_checkpoint isDadaptation = False + isProdigy = False try: import dadaptation @@ -535,6 +536,13 @@ def train_loop( except ImportError: pass + try: + import prodigyopt + + isProdigy = isinstance(optimizer.optimizer, prodigyopt.Prodigy) + except ImportError: + pass + num_training_steps += global_step_offset global_step += global_step_offset @@ -582,7 +590,7 @@ def train_loop( lr = lr.item() label = group_labels[i] if i < len(group_labels) else f"{i}" logs[f"lr/{label}"] = lr - if isDadaptation: + if isDadaptation or isProdigy: lr = ( optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] -- cgit v1.2.3-54-g00ecf