From 07c99baaf18f2b8e98b5f7d9cce2088600e63a7f Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 21 Mar 2023 17:13:53 +0100
Subject: Log DAdam/DAdan d

---
 train_dreambooth.py    |  4 ++--
 train_lora.py          |  4 ++--
 train_ti.py            |  4 ++--
 training/functional.py | 14 ++++++++++++++
 4 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/train_dreambooth.py b/train_dreambooth.py
index b706d07..f8f6e84 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -287,8 +287,8 @@ def parse_args():
     parser.add_argument(
         "--optimizer",
         type=str,
-        default="adam",
-        help='Optimizer to use ["adam", "adam8bit"]'
+        default="dadan",
+        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
     )
     parser.add_argument(
         "--adam_beta1",
diff --git a/train_lora.py b/train_lora.py
index ce8fb50..787f271 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -245,8 +245,8 @@ def parse_args():
     parser.add_argument(
         "--optimizer",
         type=str,
-        default="adam",
-        help='Optimizer to use ["adam", "adam8bit"]'
+        default="dadan",
+        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
     )
     parser.add_argument(
         "--adam_beta1",
diff --git a/train_ti.py b/train_ti.py
index ee65b44..036c288 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -293,8 +293,8 @@ def parse_args():
     parser.add_argument(
         "--optimizer",
         type=str,
-        default="adam",
-        help='Optimizer to use ["adam", "adam8bit"]'
+        default="dadan",
+        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
     )
     parser.add_argument(
         "--adam_beta1",
diff --git a/training/functional.py b/training/functional.py
index 43ee356..77f056e 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -424,6 +424,15 @@ def train_loop(
     on_sample = callbacks.on_sample
     on_checkpoint = callbacks.on_checkpoint
 
+    isDadaptation = False
+
+    try:
+        import dadaptation
+
+        isDadaptation = isinstance(optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan))
+    except ImportError:
+        pass
+
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
@@ -461,6 +470,11 @@ def train_loop(
                     "train/cur_acc": acc.item(),
                     "lr": lr_scheduler.get_last_lr()[0],
                 }
+                if isDadaptation:
+                    logs["lr/d*lr"] = (
+                        optimizer.param_groups[0]["d"] *
+                        optimizer.param_groups[0]["lr"]
+                    )
                 logs.update(on_log())
 
                 local_progress_bar.set_postfix(**logs)
--
cgit v1.2.3-70-g09d2
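
Below is a minimal standalone sketch of the quantity this patch logs, assuming
the `dadaptation` package is installed; the toy model, data, and training step
here are hypothetical stand-ins, not part of the patch. D-Adaptation optimizers
keep their estimated step size under the param-group key "d", so the effective
learning rate is d * lr, which the patch exposes as "lr/d*lr".

    import torch
    import dadaptation

    # Toy model and loss (hypothetical; the real trainer uses its own model).
    model = torch.nn.Linear(4, 1)
    # D-Adaptation is typically run with lr=1.0; it adapts the step size via d.
    optimizer = dadaptation.DAdaptAdam(model.parameters(), lr=1.0)

    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # The same quantity the patch logs as "lr/d*lr".
    group = optimizer.param_groups[0]
    print("lr/d*lr:", group["d"] * group["lr"])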