summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 84197c8..beb65fc 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -348,6 +348,7 @@ def parse_args():
348 "dadan", 348 "dadan",
349 "dlion", 349 "dlion",
350 "adafactor", 350 "adafactor",
351 "prodigy",
351 ], 352 ],
352 help="Optimizer to use", 353 help="Optimizer to use",
353 ) 354 )
@@ -828,6 +829,21 @@ def main():
828 args.learning_rate_text = 1.0 829 args.learning_rate_text = 1.0
829 elif args.optimizer == "dlion": 830 elif args.optimizer == "dlion":
830 raise ImportError("DLion has not been merged into dadaptation yet") 831 raise ImportError("DLion has not been merged into dadaptation yet")
832 elif args.optimizer == "prodigy":
833 try:
834 import prodigyopt
835 except ImportError:
836 raise ImportError(
837 "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`."
838 )
839
840 create_optimizer = partial(
841 prodigyopt.Prodigy,
842 weight_decay=args.adam_weight_decay,
843 )
844
845 args.learning_rate_unet = 1.0
846 args.learning_rate_text = 1.0
831 else: 847 else:
832 raise ValueError(f'Unknown --optimizer "{args.optimizer}"') 848 raise ValueError(f'Unknown --optimizer "{args.optimizer}"')
833 849