summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
-rw-r--r-- train_dreambooth.py 4
-rw-r--r-- train_lora.py 4
-rw-r--r-- train_ti.py 4
-rw-r--r-- training/functional.py 14
4 files changed, 20 insertions, 6 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index b706d07..f8f6e84 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -287,8 +287,8 @@ def parse_args():
287 parser.add_argument( 287 parser.add_argument(
288 "--optimizer", 288 "--optimizer",
289 type=str, 289 type=str,
290 default="adam", 290 default="dadan",
291 help='Optimizer to use ["adam", "adam8bit"]' 291 help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
292 ) 292 )
293 parser.add_argument( 293 parser.add_argument(
294 "--adam_beta1", 294 "--adam_beta1",
diff --git a/train_lora.py b/train_lora.py
index ce8fb50..787f271 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -245,8 +245,8 @@ def parse_args():
245 parser.add_argument( 245 parser.add_argument(
246 "--optimizer", 246 "--optimizer",
247 type=str, 247 type=str,
248 default="adam", 248 default="dadan",
249 help='Optimizer to use ["adam", "adam8bit"]' 249 help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
250 ) 250 )
251 parser.add_argument( 251 parser.add_argument(
252 "--adam_beta1", 252 "--adam_beta1",
diff --git a/train_ti.py b/train_ti.py
index ee65b44..036c288 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -293,8 +293,8 @@ def parse_args():
293 parser.add_argument( 293 parser.add_argument(
294 "--optimizer", 294 "--optimizer",
295 type=str, 295 type=str,
296 default="adam", 296 default="dadan",
297 help='Optimizer to use ["adam", "adam8bit"]' 297 help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
298 ) 298 )
299 parser.add_argument( 299 parser.add_argument(
300 "--adam_beta1", 300 "--adam_beta1",
diff --git a/training/functional.py b/training/functional.py
index 43ee356..77f056e 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -424,6 +424,15 @@ def train_loop(
424 on_sample = callbacks.on_sample 424 on_sample = callbacks.on_sample
425 on_checkpoint = callbacks.on_checkpoint 425 on_checkpoint = callbacks.on_checkpoint
426 426
427 isDadaptation = False
428
429 try:
430 import dadaptation
431
432 isDadaptation = isinstance(optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan))
433 except ImportError:
434 pass
435
427 try: 436 try:
428 for epoch in range(num_epochs): 437 for epoch in range(num_epochs):
429 if accelerator.is_main_process: 438 if accelerator.is_main_process:
@@ -461,6 +470,11 @@ def train_loop(
461 "train/cur_acc": acc.item(), 470 "train/cur_acc": acc.item(),
462 "lr": lr_scheduler.get_last_lr()[0], 471 "lr": lr_scheduler.get_last_lr()[0],
463 } 472 }
473 if isDadaptation:
474 logs["lr/d*lr"] = (
475 optimizer.param_groups[0]["d"] *
476 optimizer.param_groups[0]["lr"]
477 )
464 logs.update(on_log()) 478 logs.update(on_log())
465 479
466 local_progress_bar.set_postfix(**logs) 480 local_progress_bar.set_postfix(**logs)