diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/training/functional.py b/training/functional.py index 43ee356..77f056e 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -424,6 +424,15 @@ def train_loop( | |||
424 | on_sample = callbacks.on_sample | 424 | on_sample = callbacks.on_sample |
425 | on_checkpoint = callbacks.on_checkpoint | 425 | on_checkpoint = callbacks.on_checkpoint |
426 | 426 | ||
427 | isDadaptation = False | ||
428 | |||
429 | try: | ||
430 | import dadaptation | ||
431 | |||
432 | isDadaptation = isinstance(optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan)) | ||
433 | except ImportError: | ||
434 | pass | ||
435 | |||
427 | try: | 436 | try: |
428 | for epoch in range(num_epochs): | 437 | for epoch in range(num_epochs): |
429 | if accelerator.is_main_process: | 438 | if accelerator.is_main_process: |
@@ -461,6 +470,11 @@ def train_loop( | |||
461 | "train/cur_acc": acc.item(), | 470 | "train/cur_acc": acc.item(), |
462 | "lr": lr_scheduler.get_last_lr()[0], | 471 | "lr": lr_scheduler.get_last_lr()[0], |
463 | } | 472 | } |
473 | if isDadaptation: | ||
474 | logs["lr/d*lr"] = ( | ||
475 | optimizer.param_groups[0]["d"] * | ||
476 | optimizer.param_groups[0]["lr"] | ||
477 | ) | ||
464 | logs.update(on_log()) | 478 | logs.update(on_log()) |
465 | 479 | ||
466 | local_progress_bar.set_postfix(**logs) | 480 | local_progress_bar.set_postfix(**logs) |