| author | Volpeon <git@volpeon.ink> | 2023-06-24 20:13:03 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-06-24 20:13:03 +0200 |
| commit | 4f724ca8015771c55ab9f382ebec5fd8b3309eb2 | |
| tree | 70ca415e0baa76ad79337cc476d80bee091628f0 | |
| parent | Fix | |
Added Prodigy optimizer
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | environment.yaml | 1 |
| -rw-r--r-- | train_dreambooth.py | 16 |
| -rw-r--r-- | training/functional.py | 10 |
3 files changed, 26 insertions(+), 1 deletion(-)
```diff
diff --git a/environment.yaml b/environment.yaml
index 2c81a90..4a73688 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -27,6 +27,7 @@ dependencies:
 - bitsandbytes==0.39.1
 - lion-pytorch==0.0.7
 - peft==0.3.0
+- prodigyopt==1.0
 - python-slugify>=6.1.2
 - safetensors==0.3.1
 - setuptools==65.6.3
```
```diff
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 84197c8..beb65fc 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -348,6 +348,7 @@ def parse_args():
             "dadan",
             "dlion",
             "adafactor",
+            "prodigy",
         ],
         help="Optimizer to use",
     )
@@ -828,6 +829,21 @@ def main():
         args.learning_rate_text = 1.0
     elif args.optimizer == "dlion":
         raise ImportError("DLion has not been merged into dadaptation yet")
+    elif args.optimizer == "prodigy":
+        try:
+            import prodigyopt
+        except ImportError:
+            raise ImportError(
+                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`."
+            )
+
+        create_optimizer = partial(
+            prodigyopt.Prodigy,
+            weight_decay=args.adam_weight_decay,
+        )
+
+        args.learning_rate_unet = 1.0
+        args.learning_rate_text = 1.0
     else:
         raise ValueError(f'Unknown --optimizer "{args.optimizer}"')
 
```
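For context, a minimal standalone sketch of how the optimizer wired up above behaves when used directly; `model` and the training loop here are placeholders, not code from this repository. As with the D-Adaptation optimizers, the base learning rate is pinned to 1.0 and Prodigy scales the effective step size through its internal distance estimate `d`:

```python
# Minimal sketch, not the training script: `model` and the loop are
# illustrative placeholders. Mirrors the diff: lr stays at 1.0 and
# weight_decay is the only hyperparameter passed through.
import torch
from prodigyopt import Prodigy

model = torch.nn.Linear(16, 4)
optimizer = Prodigy(
    model.parameters(),
    lr=1.0,             # Prodigy expects lr=1.0; `d` supplies the actual scale
    weight_decay=1e-2,  # corresponds to args.adam_weight_decay in the diff
)

for _ in range(10):
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```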
```diff
diff --git a/training/functional.py b/training/functional.py
index 34a701b..cc079ef 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -525,6 +525,7 @@ def train_loop(
     on_checkpoint = callbacks.on_checkpoint
 
     isDadaptation = False
+    isProdigy = False
 
     try:
         import dadaptation
@@ -535,6 +536,13 @@ def train_loop(
     except ImportError:
         pass
 
+    try:
+        import prodigyopt
+
+        isProdigy = isinstance(optimizer.optimizer, prodigyopt.Prodigy)
+    except ImportError:
+        pass
+
     num_training_steps += global_step_offset
     global_step += global_step_offset
 
@@ -582,7 +590,7 @@ def train_loop(
             lr = lr.item()
         label = group_labels[i] if i < len(group_labels) else f"{i}"
         logs[f"lr/{label}"] = lr
-        if isDadaptation:
+        if isDadaptation or isProdigy:
             lr = (
                 optimizer.param_groups[i]["d"]
                 * optimizer.param_groups[i]["lr"]
```
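The last hunk reuses the D-Adaptation logging path: both dadaptation and Prodigy store their adaptation factor under the `"d"` key of each parameter group, so the logged value `d * lr` is the step size actually applied rather than the nominal `lr` of 1.0. A minimal sketch of that computation in isolation, where `effective_lrs` is a hypothetical helper and `optimizer` is assumed to be a bare `prodigyopt.Prodigy` instance (in the diff the optimizer is wrapped, which is why the isinstance check goes through `optimizer.optimizer`):

```python
# Sketch of the effective-lr computation from the hunk above; assumes
# `optimizer` is a prodigyopt.Prodigy (or dadaptation) instance whose
# param groups carry the adaptation factor under "d".
def effective_lrs(optimizer) -> list[float]:
    return [group["d"] * group["lr"] for group in optimizer.param_groups]
```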
