diff options
| author | Volpeon <git@volpeon.ink> | 2023-06-24 20:13:03 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-06-24 20:13:03 +0200 |
| commit | 4f724ca8015771c55ab9f382ebec5fd8b3309eb2 (patch) | |
| tree | 70ca415e0baa76ad79337cc476d80bee091628f0 /train_dreambooth.py | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-4f724ca8015771c55ab9f382ebec5fd8b3309eb2.tar.gz textual-inversion-diff-4f724ca8015771c55ab9f382ebec5fd8b3309eb2.tar.bz2 textual-inversion-diff-4f724ca8015771c55ab9f382ebec5fd8b3309eb2.zip | |
Added Prodigy optimizer
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 84197c8..beb65fc 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -348,6 +348,7 @@ def parse_args(): | |||
| 348 | "dadan", | 348 | "dadan", |
| 349 | "dlion", | 349 | "dlion", |
| 350 | "adafactor", | 350 | "adafactor", |
| 351 | "prodigy", | ||
| 351 | ], | 352 | ], |
| 352 | help="Optimizer to use", | 353 | help="Optimizer to use", |
| 353 | ) | 354 | ) |
| @@ -828,6 +829,21 @@ def main(): | |||
| 828 | args.learning_rate_text = 1.0 | 829 | args.learning_rate_text = 1.0 |
| 829 | elif args.optimizer == "dlion": | 830 | elif args.optimizer == "dlion": |
| 830 | raise ImportError("DLion has not been merged into dadaptation yet") | 831 | raise ImportError("DLion has not been merged into dadaptation yet") |
| 832 | elif args.optimizer == "prodigy": | ||
| 833 | try: | ||
| 834 | import prodigyopt | ||
| 835 | except ImportError: | ||
| 836 | raise ImportError( | ||
| 837 | "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`." | ||
| 838 | ) | ||
| 839 | |||
| 840 | create_optimizer = partial( | ||
| 841 | prodigyopt.Prodigy, | ||
| 842 | weight_decay=args.adam_weight_decay, | ||
| 843 | ) | ||
| 844 | |||
| 845 | args.learning_rate_unet = 1.0 | ||
| 846 | args.learning_rate_text = 1.0 | ||
| 831 | else: | 847 | else: |
| 832 | raise ValueError(f'Unknown --optimizer "{args.optimizer}"') | 848 | raise ValueError(f'Unknown --optimizer "{args.optimizer}"') |
| 833 | 849 | ||
