diff options
author | Volpeon <git@volpeon.ink> | 2023-06-24 20:13:03 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-06-24 20:13:03 +0200 |
commit | 4f724ca8015771c55ab9f382ebec5fd8b3309eb2 (patch) | |
tree | 70ca415e0baa76ad79337cc476d80bee091628f0 /train_dreambooth.py | |
parent | Fix (diff) | |
download | textual-inversion-diff-4f724ca8015771c55ab9f382ebec5fd8b3309eb2.tar.gz textual-inversion-diff-4f724ca8015771c55ab9f382ebec5fd8b3309eb2.tar.bz2 textual-inversion-diff-4f724ca8015771c55ab9f382ebec5fd8b3309eb2.zip |
Added Prodigy optimizer
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 84197c8..beb65fc 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -348,6 +348,7 @@ def parse_args(): | |||
348 | "dadan", | 348 | "dadan", |
349 | "dlion", | 349 | "dlion", |
350 | "adafactor", | 350 | "adafactor", |
351 | "prodigy", | ||
351 | ], | 352 | ], |
352 | help="Optimizer to use", | 353 | help="Optimizer to use", |
353 | ) | 354 | ) |
@@ -828,6 +829,21 @@ def main(): | |||
828 | args.learning_rate_text = 1.0 | 829 | args.learning_rate_text = 1.0 |
829 | elif args.optimizer == "dlion": | 830 | elif args.optimizer == "dlion": |
830 | raise ImportError("DLion has not been merged into dadaptation yet") | 831 | raise ImportError("DLion has not been merged into dadaptation yet") |
832 | elif args.optimizer == "prodigy": | ||
833 | try: | ||
834 | import prodigyopt | ||
835 | except ImportError: | ||
836 | raise ImportError( | ||
837 | "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`." | ||
838 | ) | ||
839 | |||
840 | create_optimizer = partial( | ||
841 | prodigyopt.Prodigy, | ||
842 | weight_decay=args.adam_weight_decay, | ||
843 | ) | ||
844 | |||
845 | args.learning_rate_unet = 1.0 | ||
846 | args.learning_rate_text = 1.0 | ||
831 | else: | 847 | else: |
832 | raise ValueError(f'Unknown --optimizer "{args.optimizer}"') | 848 | raise ValueError(f'Unknown --optimizer "{args.optimizer}"') |
833 | 849 | ||