From 07c99baaf18f2b8e98b5f7d9cce2088600e63a7f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 21 Mar 2023 17:13:53 +0100 Subject: Log DAdam/DAdan d --- training/functional.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 43ee356..77f056e 100644 --- a/training/functional.py +++ b/training/functional.py @@ -424,6 +424,15 @@ def train_loop( on_sample = callbacks.on_sample on_checkpoint = callbacks.on_checkpoint + isDadaptation = False + + try: + import dadaptation + + isDadaptation = isinstance(optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan)) + except ImportError: + pass + try: for epoch in range(num_epochs): if accelerator.is_main_process: @@ -461,6 +470,11 @@ def train_loop( "train/cur_acc": acc.item(), "lr": lr_scheduler.get_last_lr()[0], } + if isDadaptation: + logs["lr/d*lr"] = ( + optimizer.param_groups[0]["d"] * + optimizer.param_groups[0]["lr"] + ) logs.update(on_log()) local_progress_bar.set_postfix(**logs) -- cgit v1.2.3-70-g09d2