diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/training/functional.py b/training/functional.py index 43ee356..77f056e 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -424,6 +424,15 @@ def train_loop( | |||
| 424 | on_sample = callbacks.on_sample | 424 | on_sample = callbacks.on_sample |
| 425 | on_checkpoint = callbacks.on_checkpoint | 425 | on_checkpoint = callbacks.on_checkpoint |
| 426 | 426 | ||
| 427 | isDadaptation = False | ||
| 428 | |||
| 429 | try: | ||
| 430 | import dadaptation | ||
| 431 | |||
| 432 | isDadaptation = isinstance(optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan)) | ||
| 433 | except ImportError: | ||
| 434 | pass | ||
| 435 | |||
| 427 | try: | 436 | try: |
| 428 | for epoch in range(num_epochs): | 437 | for epoch in range(num_epochs): |
| 429 | if accelerator.is_main_process: | 438 | if accelerator.is_main_process: |
| @@ -461,6 +470,11 @@ def train_loop( | |||
| 461 | "train/cur_acc": acc.item(), | 470 | "train/cur_acc": acc.item(), |
| 462 | "lr": lr_scheduler.get_last_lr()[0], | 471 | "lr": lr_scheduler.get_last_lr()[0], |
| 463 | } | 472 | } |
| 473 | if isDadaptation: | ||
| 474 | logs["lr/d*lr"] = ( | ||
| 475 | optimizer.param_groups[0]["d"] * | ||
| 476 | optimizer.param_groups[0]["lr"] | ||
| 477 | ) | ||
| 464 | logs.update(on_log()) | 478 | logs.update(on_log()) |
| 465 | 479 | ||
| 466 | local_progress_bar.set_postfix(**logs) | 480 | local_progress_bar.set_postfix(**logs) |
