From 4f724ca8015771c55ab9f382ebec5fd8b3309eb2 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 24 Jun 2023 20:13:03 +0200
Subject: Added Prodigy optimizer

---
 environment.yaml       |  1 +
 train_dreambooth.py    | 16 ++++++++++++++++
 training/functional.py | 10 +++++++++-
 3 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/environment.yaml b/environment.yaml
index 2c81a90..4a73688 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -27,6 +27,7 @@ dependencies:
   - bitsandbytes==0.39.1
   - lion-pytorch==0.0.7
   - peft==0.3.0
+  - prodigyopt==1.0
   - python-slugify>=6.1.2
   - safetensors==0.3.1
   - setuptools==65.6.3
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 84197c8..beb65fc 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -348,6 +348,7 @@ def parse_args():
             "dadan",
             "dlion",
             "adafactor",
+            "prodigy",
         ],
         help="Optimizer to use",
     )
@@ -828,6 +829,21 @@ def main():
         args.learning_rate_text = 1.0
     elif args.optimizer == "dlion":
         raise ImportError("DLion has not been merged into dadaptation yet")
+    elif args.optimizer == "prodigy":
+        try:
+            import prodigyopt
+        except ImportError:
+            raise ImportError(
+                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`."
+            )
+
+        create_optimizer = partial(
+            prodigyopt.Prodigy,
+            weight_decay=args.adam_weight_decay,
+        )
+
+        args.learning_rate_unet = 1.0
+        args.learning_rate_text = 1.0
     else:
         raise ValueError(f'Unknown --optimizer "{args.optimizer}"')

diff --git a/training/functional.py b/training/functional.py
index 34a701b..cc079ef 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -525,6 +525,7 @@ def train_loop(
     on_checkpoint = callbacks.on_checkpoint

     isDadaptation = False
+    isProdigy = False

     try:
         import dadaptation
@@ -535,6 +536,13 @@
     except ImportError:
         pass

+    try:
+        import prodigyopt
+
+        isProdigy = isinstance(optimizer.optimizer, prodigyopt.Prodigy)
+    except ImportError:
+        pass
+
     num_training_steps += global_step_offset
     global_step += global_step_offset

@@ -582,7 +590,7 @@
                         lr = lr.item()
                     label = group_labels[i] if i < len(group_labels) else f"{i}"
                     logs[f"lr/{label}"] = lr
-                    if isDadaptation:
+                    if isDadaptation or isProdigy:
                         lr = (
                             optimizer.param_groups[i]["d"]
                             * optimizer.param_groups[i]["lr"]
--
cgit v1.2.3-54-g00ecf
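
Note (not part of the patch): a minimal sketch of how the Prodigy optimizer wired in above is meant to be used, assuming `prodigyopt` is installed. The toy linear model and the weight-decay value are placeholders; in the patch, `weight_decay` comes from `args.adam_weight_decay` and the optimizer is built through `create_optimizer`.

# Sketch only: the patch sets args.learning_rate_unet/_text to 1.0 because
# Prodigy estimates the step size (d) itself; lr acts as a multiplier on d.
import torch
import prodigyopt

model = torch.nn.Linear(16, 4)  # placeholder stand-in for the UNet / text encoder

optimizer = prodigyopt.Prodigy(
    model.parameters(),
    lr=1.0,             # mirrors the patch's learning_rate_* = 1.0
    weight_decay=1e-2,  # placeholder for args.adam_weight_decay
)

loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()

# Same quantity the train_loop change logs for D-Adaptation and Prodigy:
# the effective learning rate per parameter group is d * lr.
for i, group in enumerate(optimizer.param_groups):
    print(f"group {i}: effective lr = {group['d'] * group['lr']}")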