diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/training/functional.py b/training/functional.py index 34a701b..cc079ef 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -525,6 +525,7 @@ def train_loop( | |||
525 | on_checkpoint = callbacks.on_checkpoint | 525 | on_checkpoint = callbacks.on_checkpoint |
526 | 526 | ||
527 | isDadaptation = False | 527 | isDadaptation = False |
528 | isProdigy = False | ||
528 | 529 | ||
529 | try: | 530 | try: |
530 | import dadaptation | 531 | import dadaptation |
@@ -535,6 +536,13 @@ def train_loop( | |||
535 | except ImportError: | 536 | except ImportError: |
536 | pass | 537 | pass |
537 | 538 | ||
539 | try: | ||
540 | import prodigyopt | ||
541 | |||
542 | isProdigy = isinstance(optimizer.optimizer, prodigyopt.Prodigy) | ||
543 | except ImportError: | ||
544 | pass | ||
545 | |||
538 | num_training_steps += global_step_offset | 546 | num_training_steps += global_step_offset |
539 | global_step += global_step_offset | 547 | global_step += global_step_offset |
540 | 548 | ||
@@ -582,7 +590,7 @@ def train_loop( | |||
582 | lr = lr.item() | 590 | lr = lr.item() |
583 | label = group_labels[i] if i < len(group_labels) else f"{i}" | 591 | label = group_labels[i] if i < len(group_labels) else f"{i}" |
584 | logs[f"lr/{label}"] = lr | 592 | logs[f"lr/{label}"] = lr |
585 | if isDadaptation: | 593 | if isDadaptation or isProdigy: |
586 | lr = ( | 594 | lr = ( |
587 | optimizer.param_groups[i]["d"] | 595 | optimizer.param_groups[i]["d"] |
588 | * optimizer.param_groups[i]["lr"] | 596 | * optimizer.param_groups[i]["lr"] |