summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-21 17:13:53 +0100
committerVolpeon <git@volpeon.ink>2023-03-21 17:13:53 +0100
commit07c99baaf18f2b8e98b5f7d9cce2088600e63a7f (patch)
treef851213c723ceb39ad3ac26df23e8bf127c60db7 /training/functional.py
parentAdded dadaptation (diff)
downloadtextual-inversion-diff-07c99baaf18f2b8e98b5f7d9cce2088600e63a7f.tar.gz
textual-inversion-diff-07c99baaf18f2b8e98b5f7d9cce2088600e63a7f.tar.bz2
textual-inversion-diff-07c99baaf18f2b8e98b5f7d9cce2088600e63a7f.zip
Log DAdam/DAdan d
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/training/functional.py b/training/functional.py
index 43ee356..77f056e 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -424,6 +424,15 @@ def train_loop(
424 on_sample = callbacks.on_sample 424 on_sample = callbacks.on_sample
425 on_checkpoint = callbacks.on_checkpoint 425 on_checkpoint = callbacks.on_checkpoint
426 426
427 isDadaptation = False
428
429 try:
430 import dadaptation
431
432 isDadaptation = isinstance(optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan))
433 except ImportError:
434 pass
435
427 try: 436 try:
428 for epoch in range(num_epochs): 437 for epoch in range(num_epochs):
429 if accelerator.is_main_process: 438 if accelerator.is_main_process:
@@ -461,6 +470,11 @@ def train_loop(
461 "train/cur_acc": acc.item(), 470 "train/cur_acc": acc.item(),
462 "lr": lr_scheduler.get_last_lr()[0], 471 "lr": lr_scheduler.get_last_lr()[0],
463 } 472 }
473 if isDadaptation:
474 logs["lr/d*lr"] = (
475 optimizer.param_groups[0]["d"] *
476 optimizer.param_groups[0]["lr"]
477 )
464 logs.update(on_log()) 478 logs.update(on_log())
465 479
466 local_progress_bar.set_postfix(**logs) 480 local_progress_bar.set_postfix(**logs)