diff options
-rw-r--r-- | environment.yaml | 1 | ||||
-rw-r--r-- | train_dreambooth.py | 16 | ||||
-rw-r--r-- | training/functional.py | 10 |
3 files changed, 26 insertions, 1 deletions
diff --git a/environment.yaml b/environment.yaml index 2c81a90..4a73688 100644 --- a/environment.yaml +++ b/environment.yaml | |||
@@ -27,6 +27,7 @@ dependencies: | |||
27 | - bitsandbytes==0.39.1 | 27 | - bitsandbytes==0.39.1 |
28 | - lion-pytorch==0.0.7 | 28 | - lion-pytorch==0.0.7 |
29 | - peft==0.3.0 | 29 | - peft==0.3.0 |
30 | - prodigyopt==1.0 | ||
30 | - python-slugify>=6.1.2 | 31 | - python-slugify>=6.1.2 |
31 | - safetensors==0.3.1 | 32 | - safetensors==0.3.1 |
32 | - setuptools==65.6.3 | 33 | - setuptools==65.6.3 |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 84197c8..beb65fc 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -348,6 +348,7 @@ def parse_args(): | |||
348 | "dadan", | 348 | "dadan", |
349 | "dlion", | 349 | "dlion", |
350 | "adafactor", | 350 | "adafactor", |
351 | "prodigy", | ||
351 | ], | 352 | ], |
352 | help="Optimizer to use", | 353 | help="Optimizer to use", |
353 | ) | 354 | ) |
@@ -828,6 +829,21 @@ def main(): | |||
828 | args.learning_rate_text = 1.0 | 829 | args.learning_rate_text = 1.0 |
829 | elif args.optimizer == "dlion": | 830 | elif args.optimizer == "dlion": |
830 | raise ImportError("DLion has not been merged into dadaptation yet") | 831 | raise ImportError("DLion has not been merged into dadaptation yet") |
832 | elif args.optimizer == "prodigy": | ||
833 | try: | ||
834 | import prodigyopt | ||
835 | except ImportError: | ||
836 | raise ImportError( | ||
837 | "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`." | ||
838 | ) | ||
839 | |||
840 | create_optimizer = partial( | ||
841 | prodigyopt.Prodigy, | ||
842 | weight_decay=args.adam_weight_decay, | ||
843 | ) | ||
844 | |||
845 | args.learning_rate_unet = 1.0 | ||
846 | args.learning_rate_text = 1.0 | ||
831 | else: | 847 | else: |
832 | raise ValueError(f'Unknown --optimizer "{args.optimizer}"') | 848 | raise ValueError(f'Unknown --optimizer "{args.optimizer}"') |
833 | 849 | ||
diff --git a/training/functional.py b/training/functional.py index 34a701b..cc079ef 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -525,6 +525,7 @@ def train_loop( | |||
525 | on_checkpoint = callbacks.on_checkpoint | 525 | on_checkpoint = callbacks.on_checkpoint |
526 | 526 | ||
527 | isDadaptation = False | 527 | isDadaptation = False |
528 | isProdigy = False | ||
528 | 529 | ||
529 | try: | 530 | try: |
530 | import dadaptation | 531 | import dadaptation |
@@ -535,6 +536,13 @@ def train_loop( | |||
535 | except ImportError: | 536 | except ImportError: |
536 | pass | 537 | pass |
537 | 538 | ||
539 | try: | ||
540 | import prodigyopt | ||
541 | |||
542 | isProdigy = isinstance(optimizer.optimizer, prodigyopt.Prodigy) | ||
543 | except ImportError: | ||
544 | pass | ||
545 | |||
538 | num_training_steps += global_step_offset | 546 | num_training_steps += global_step_offset |
539 | global_step += global_step_offset | 547 | global_step += global_step_offset |
540 | 548 | ||
@@ -582,7 +590,7 @@ def train_loop( | |||
582 | lr = lr.item() | 590 | lr = lr.item() |
583 | label = group_labels[i] if i < len(group_labels) else f"{i}" | 591 | label = group_labels[i] if i < len(group_labels) else f"{i}" |
584 | logs[f"lr/{label}"] = lr | 592 | logs[f"lr/{label}"] = lr |
585 | if isDadaptation: | 593 | if isDadaptation or isProdigy: |
586 | lr = ( | 594 | lr = ( |
587 | optimizer.param_groups[i]["d"] | 595 | optimizer.param_groups[i]["d"] |
588 | * optimizer.param_groups[i]["lr"] | 596 | * optimizer.param_groups[i]["lr"] |