diff options
author | Volpeon <git@volpeon.ink> | 2023-03-21 17:13:53 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-21 17:13:53 +0100 |
commit | 07c99baaf18f2b8e98b5f7d9cce2088600e63a7f (patch) | |
tree | f851213c723ceb39ad3ac26df23e8bf127c60db7 | |
parent | Added dadaptation (diff) | |
download | textual-inversion-diff-07c99baaf18f2b8e98b5f7d9cce2088600e63a7f.tar.gz textual-inversion-diff-07c99baaf18f2b8e98b5f7d9cce2088600e63a7f.tar.bz2 textual-inversion-diff-07c99baaf18f2b8e98b5f7d9cce2088600e63a7f.zip |
Log DAdam/DAdan d (the optimizer's `d` value from `param_groups[0]`, logged as `d * lr` under the key `lr/d*lr`)
-rw-r--r-- | train_dreambooth.py | 4 | ||||
-rw-r--r-- | train_lora.py | 4 | ||||
-rw-r--r-- | train_ti.py | 4 | ||||
-rw-r--r-- | training/functional.py | 14 |
4 files changed, 20 insertions, 6 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index b706d07..f8f6e84 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -287,8 +287,8 @@ def parse_args(): | |||
287 | parser.add_argument( | 287 | parser.add_argument( |
288 | "--optimizer", | 288 | "--optimizer", |
289 | type=str, | 289 | type=str, |
290 | default="adam", | 290 | default="dadan", |
291 | help='Optimizer to use ["adam", "adam8bit"]' | 291 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' |
292 | ) | 292 | ) |
293 | parser.add_argument( | 293 | parser.add_argument( |
294 | "--adam_beta1", | 294 | "--adam_beta1", |
diff --git a/train_lora.py b/train_lora.py index ce8fb50..787f271 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -245,8 +245,8 @@ def parse_args(): | |||
245 | parser.add_argument( | 245 | parser.add_argument( |
246 | "--optimizer", | 246 | "--optimizer", |
247 | type=str, | 247 | type=str, |
248 | default="adam", | 248 | default="dadan", |
249 | help='Optimizer to use ["adam", "adam8bit"]' | 249 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' |
250 | ) | 250 | ) |
251 | parser.add_argument( | 251 | parser.add_argument( |
252 | "--adam_beta1", | 252 | "--adam_beta1", |
diff --git a/train_ti.py b/train_ti.py index ee65b44..036c288 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -293,8 +293,8 @@ def parse_args(): | |||
293 | parser.add_argument( | 293 | parser.add_argument( |
294 | "--optimizer", | 294 | "--optimizer", |
295 | type=str, | 295 | type=str, |
296 | default="adam", | 296 | default="dadan", |
297 | help='Optimizer to use ["adam", "adam8bit"]' | 297 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' |
298 | ) | 298 | ) |
299 | parser.add_argument( | 299 | parser.add_argument( |
300 | "--adam_beta1", | 300 | "--adam_beta1", |
diff --git a/training/functional.py b/training/functional.py index 43ee356..77f056e 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -424,6 +424,15 @@ def train_loop( | |||
424 | on_sample = callbacks.on_sample | 424 | on_sample = callbacks.on_sample |
425 | on_checkpoint = callbacks.on_checkpoint | 425 | on_checkpoint = callbacks.on_checkpoint |
426 | 426 | ||
427 | isDadaptation = False | ||
428 | |||
429 | try: | ||
430 | import dadaptation | ||
431 | |||
432 | isDadaptation = isinstance(optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan)) | ||
433 | except ImportError: | ||
434 | pass | ||
435 | |||
427 | try: | 436 | try: |
428 | for epoch in range(num_epochs): | 437 | for epoch in range(num_epochs): |
429 | if accelerator.is_main_process: | 438 | if accelerator.is_main_process: |
@@ -461,6 +470,11 @@ def train_loop( | |||
461 | "train/cur_acc": acc.item(), | 470 | "train/cur_acc": acc.item(), |
462 | "lr": lr_scheduler.get_last_lr()[0], | 471 | "lr": lr_scheduler.get_last_lr()[0], |
463 | } | 472 | } |
473 | if isDadaptation: | ||
474 | logs["lr/d*lr"] = ( | ||
475 | optimizer.param_groups[0]["d"] * | ||
476 | optimizer.param_groups[0]["lr"] | ||
477 | ) | ||
464 | logs.update(on_log()) | 478 | logs.update(on_log()) |
465 | 479 | ||
466 | local_progress_bar.set_postfix(**logs) | 480 | local_progress_bar.set_postfix(**logs) |