Diffstat (limited to 'training')
-rw-r--r--  training/functional.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/training/functional.py b/training/functional.py
index 34a701b..cc079ef 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -525,6 +525,7 @@ def train_loop(
     on_checkpoint = callbacks.on_checkpoint
 
     isDadaptation = False
+    isProdigy = False
 
     try:
         import dadaptation
@@ -535,6 +536,13 @@ def train_loop(
     except ImportError:
         pass
 
+    try:
+        import prodigyopt
+
+        isProdigy = isinstance(optimizer.optimizer, prodigyopt.Prodigy)
+    except ImportError:
+        pass
+
     num_training_steps += global_step_offset
     global_step += global_step_offset
 
@@ -582,7 +590,7 @@ def train_loop(
                 lr = lr.item()
             label = group_labels[i] if i < len(group_labels) else f"{i}"
             logs[f"lr/{label}"] = lr
-            if isDadaptation:
+            if isDadaptation or isProdigy:
                 lr = (
                     optimizer.param_groups[i]["d"]
                     * optimizer.param_groups[i]["lr"]
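
For context on what the patch logs: D-Adaptation and Prodigy are learning-rate-free optimizers that estimate the step-size scale at runtime and store it as "d" in each entry of optimizer.param_groups, so the learning rate actually applied is d * lr rather than the nominal lr (typically left at 1.0). A minimal sketch of that computation, assuming the prodigyopt package is installed; effective_lrs() is an illustrative helper, not part of training/functional.py:

    # Minimal sketch, assuming prodigyopt is installed.
    # effective_lrs() is a hypothetical helper, not part of the patched file.
    import torch

    def effective_lrs(optimizer):
        # Prodigy and D-Adaptation keep a runtime distance estimate "d" in
        # each param group; the step size actually used is d * lr, which is
        # what the patched logging reports instead of the nominal lr.
        return [group["d"] * group["lr"] for group in optimizer.param_groups]

    try:
        import prodigyopt

        model = torch.nn.Linear(4, 1)
        # Prodigy is normally run with the nominal lr left at 1.0, so "d"
        # alone carries the adaptively estimated step size.
        opt = prodigyopt.Prodigy(model.parameters(), lr=1.0)
        print(effective_lrs(opt))  # starts at the small d0 seed and grows
    except ImportError:
        pass  # same guard the patch uses when prodigyopt is absent

Without the d factor, a Prodigy or D-Adaptation run would log a flat nominal lr and hide the adaptive schedule, which is presumably why the existing isDadaptation branch is extended to "isDadaptation or isProdigy" rather than given a separate code path.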