 train_dreambooth.py    |  4 ++--
 train_lora.py          |  4 ++--
 train_ti.py            |  4 ++--
 training/functional.py | 14 ++++++++++++++
 4 files changed, 20 insertions(+), 6 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index b706d07..f8f6e84 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -287,8 +287,8 @@ def parse_args():
     parser.add_argument(
         "--optimizer",
         type=str,
-        default="adam",
-        help='Optimizer to use ["adam", "adam8bit"]'
+        default="dadan",
+        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
     )
     parser.add_argument(
         "--adam_beta1",
diff --git a/train_lora.py b/train_lora.py
index ce8fb50..787f271 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -245,8 +245,8 @@ def parse_args():
     parser.add_argument(
         "--optimizer",
         type=str,
-        default="adam",
-        help='Optimizer to use ["adam", "adam8bit"]'
+        default="dadan",
+        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
     )
     parser.add_argument(
         "--adam_beta1",
diff --git a/train_ti.py b/train_ti.py
index ee65b44..036c288 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -293,8 +293,8 @@ def parse_args():
     parser.add_argument(
         "--optimizer",
         type=str,
-        default="adam",
-        help='Optimizer to use ["adam", "adam8bit"]'
+        default="dadan",
+        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
    )
     parser.add_argument(
         "--adam_beta1",
diff --git a/training/functional.py b/training/functional.py
index 43ee356..77f056e 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -424,6 +424,15 @@ def train_loop(
     on_sample = callbacks.on_sample
     on_checkpoint = callbacks.on_checkpoint
 
+    isDadaptation = False
+
+    try:
+        import dadaptation
+
+        isDadaptation = isinstance(optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan))
+    except ImportError:
+        pass
+
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
@@ -461,6 +470,11 @@ def train_loop(
                     "train/cur_acc": acc.item(),
                     "lr": lr_scheduler.get_last_lr()[0],
                 }
+                if isDadaptation:
+                    logs["lr/d*lr"] = (
+                        optimizer.param_groups[0]["d"] *
+                        optimizer.param_groups[0]["lr"]
+                    )
                 logs.update(on_log())
 
                 local_progress_bar.set_postfix(**logs)
