summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--environment.yaml1
-rw-r--r--train_dreambooth.py16
-rw-r--r--training/functional.py10
3 files changed, 26 insertions, 1 deletions
diff --git a/environment.yaml b/environment.yaml
index 2c81a90..4a73688 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -27,6 +27,7 @@ dependencies:
27 - bitsandbytes==0.39.1 27 - bitsandbytes==0.39.1
28 - lion-pytorch==0.0.7 28 - lion-pytorch==0.0.7
29 - peft==0.3.0 29 - peft==0.3.0
30 - prodigyopt==1.0
30 - python-slugify>=6.1.2 31 - python-slugify>=6.1.2
31 - safetensors==0.3.1 32 - safetensors==0.3.1
32 - setuptools==65.6.3 33 - setuptools==65.6.3
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 84197c8..beb65fc 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -348,6 +348,7 @@ def parse_args():
348 "dadan", 348 "dadan",
349 "dlion", 349 "dlion",
350 "adafactor", 350 "adafactor",
351 "prodigy",
351 ], 352 ],
352 help="Optimizer to use", 353 help="Optimizer to use",
353 ) 354 )
@@ -828,6 +829,21 @@ def main():
828 args.learning_rate_text = 1.0 829 args.learning_rate_text = 1.0
829 elif args.optimizer == "dlion": 830 elif args.optimizer == "dlion":
830 raise ImportError("DLion has not been merged into dadaptation yet") 831 raise ImportError("DLion has not been merged into dadaptation yet")
832 elif args.optimizer == "prodigy":
833 try:
834 import prodigyopt
835 except ImportError:
836 raise ImportError(
837 "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`."
838 )
839
840 create_optimizer = partial(
841 prodigyopt.Prodigy,
842 weight_decay=args.adam_weight_decay,
843 )
844
845 args.learning_rate_unet = 1.0
846 args.learning_rate_text = 1.0
831 else: 847 else:
832 raise ValueError(f'Unknown --optimizer "{args.optimizer}"') 848 raise ValueError(f'Unknown --optimizer "{args.optimizer}"')
833 849
diff --git a/training/functional.py b/training/functional.py
index 34a701b..cc079ef 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -525,6 +525,7 @@ def train_loop(
525 on_checkpoint = callbacks.on_checkpoint 525 on_checkpoint = callbacks.on_checkpoint
526 526
527 isDadaptation = False 527 isDadaptation = False
528 isProdigy = False
528 529
529 try: 530 try:
530 import dadaptation 531 import dadaptation
@@ -535,6 +536,13 @@ def train_loop(
535 except ImportError: 536 except ImportError:
536 pass 537 pass
537 538
539 try:
540 import prodigyopt
541
542 isProdigy = isinstance(optimizer.optimizer, prodigyopt.Prodigy)
543 except ImportError:
544 pass
545
538 num_training_steps += global_step_offset 546 num_training_steps += global_step_offset
539 global_step += global_step_offset 547 global_step += global_step_offset
540 548
@@ -582,7 +590,7 @@ def train_loop(
582 lr = lr.item() 590 lr = lr.item()
583 label = group_labels[i] if i < len(group_labels) else f"{i}" 591 label = group_labels[i] if i < len(group_labels) else f"{i}"
584 logs[f"lr/{label}"] = lr 592 logs[f"lr/{label}"] = lr
585 if isDadaptation: 593 if isDadaptation or isProdigy:
586 lr = ( 594 lr = (
587 optimizer.param_groups[i]["d"] 595 optimizer.param_groups[i]["d"]
588 * optimizer.param_groups[i]["lr"] 596 * optimizer.param_groups[i]["lr"]