From 4f724ca8015771c55ab9f382ebec5fd8b3309eb2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Jun 2023 20:13:03 +0200 Subject: Added Prodigy optimizer --- train_dreambooth.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index 84197c8..beb65fc 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -348,6 +348,7 @@ def parse_args(): "dadan", "dlion", "adafactor", + "prodigy", ], help="Optimizer to use", ) @@ -828,6 +829,21 @@ def main(): args.learning_rate_text = 1.0 elif args.optimizer == "dlion": raise ImportError("DLion has not been merged into dadaptation yet") + elif args.optimizer == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError( + "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`." + ) + + create_optimizer = partial( + prodigyopt.Prodigy, + weight_decay=args.adam_weight_decay, + ) + + args.learning_rate_unet = 1.0 + args.learning_rate_text = 1.0 else: raise ValueError(f'Unknown --optimizer "{args.optimizer}"') -- cgit v1.2.3-54-g00ecf