diff options
| -rw-r--r-- | environment.yaml | 2 | ||||
| -rw-r--r-- | train_dreambooth.py | 28 | ||||
| -rw-r--r-- | train_lora.py | 28 | ||||
| -rw-r--r-- | train_ti.py | 28 |
4 files changed, 85 insertions, 1 deletions
diff --git a/environment.yaml b/environment.yaml index db43bd5..42b568f 100644 --- a/environment.yaml +++ b/environment.yaml | |||
| @@ -11,10 +11,10 @@ dependencies: | |||
| 11 | - python=3.10.8 | 11 | - python=3.10.8 |
| 12 | - pytorch=2.0.0=*cuda11.8* | 12 | - pytorch=2.0.0=*cuda11.8* |
| 13 | - torchvision=0.15.0 | 13 | - torchvision=0.15.0 |
| 14 | # - xformers=0.0.17.dev476 | ||
| 15 | - pip: | 14 | - pip: |
| 16 | - -e . | 15 | - -e . |
| 17 | - -e git+https://github.com/huggingface/diffusers#egg=diffusers | 16 | - -e git+https://github.com/huggingface/diffusers#egg=diffusers |
| 17 | - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation | ||
| 18 | - accelerate==0.17.1 | 18 | - accelerate==0.17.1 |
| 19 | - bitsandbytes==0.37.1 | 19 | - bitsandbytes==0.37.1 |
| 20 | - peft==0.2.0 | 20 | - peft==0.2.0 |
diff --git a/train_dreambooth.py b/train_dreambooth.py index dd2bf6e..b706d07 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -512,6 +512,34 @@ def main(): | |||
| 512 | eps=args.adam_epsilon, | 512 | eps=args.adam_epsilon, |
| 513 | amsgrad=args.adam_amsgrad, | 513 | amsgrad=args.adam_amsgrad, |
| 514 | ) | 514 | ) |
| 515 | elif args.optimizer == 'dadam': | ||
| 516 | try: | ||
| 517 | import dadaptation | ||
| 518 | except ImportError: | ||
| 519 | raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") | ||
| 520 | |||
| 521 | create_optimizer = partial( | ||
| 522 | dadaptation.DAdaptAdam, | ||
| 523 | betas=(args.adam_beta1, args.adam_beta2), | ||
| 524 | weight_decay=args.adam_weight_decay, | ||
| 525 | eps=args.adam_epsilon, | ||
| 526 | decouple=True, | ||
| 527 | ) | ||
| 528 | |||
| 529 | args.learning_rate = 1.0 | ||
| 530 | elif args.optimizer == 'dadan': | ||
| 531 | try: | ||
| 532 | import dadaptation | ||
| 533 | except ImportError: | ||
| 534 | raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") | ||
| 535 | |||
| 536 | create_optimizer = partial( | ||
| 537 | dadaptation.DAdaptAdan, | ||
| 538 | weight_decay=args.adam_weight_decay, | ||
| 539 | eps=args.adam_epsilon, | ||
| 540 | ) | ||
| 541 | |||
| 542 | args.learning_rate = 1.0 | ||
| 515 | else: | 543 | else: |
| 516 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 544 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") |
| 517 | 545 | ||
diff --git a/train_lora.py b/train_lora.py index 2a798f3..ce8fb50 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -476,6 +476,34 @@ def main(): | |||
| 476 | eps=args.adam_epsilon, | 476 | eps=args.adam_epsilon, |
| 477 | amsgrad=args.adam_amsgrad, | 477 | amsgrad=args.adam_amsgrad, |
| 478 | ) | 478 | ) |
| 479 | elif args.optimizer == 'dadam': | ||
| 480 | try: | ||
| 481 | import dadaptation | ||
| 482 | except ImportError: | ||
| 483 | raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") | ||
| 484 | |||
| 485 | create_optimizer = partial( | ||
| 486 | dadaptation.DAdaptAdam, | ||
| 487 | betas=(args.adam_beta1, args.adam_beta2), | ||
| 488 | weight_decay=args.adam_weight_decay, | ||
| 489 | eps=args.adam_epsilon, | ||
| 490 | decouple=True, | ||
| 491 | ) | ||
| 492 | |||
| 493 | args.learning_rate = 1.0 | ||
| 494 | elif args.optimizer == 'dadan': | ||
| 495 | try: | ||
| 496 | import dadaptation | ||
| 497 | except ImportError: | ||
| 498 | raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") | ||
| 499 | |||
| 500 | create_optimizer = partial( | ||
| 501 | dadaptation.DAdaptAdan, | ||
| 502 | weight_decay=args.adam_weight_decay, | ||
| 503 | eps=args.adam_epsilon, | ||
| 504 | ) | ||
| 505 | |||
| 506 | args.learning_rate = 1.0 | ||
| 479 | else: | 507 | else: |
| 480 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 508 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") |
| 481 | 509 | ||
diff --git a/train_ti.py b/train_ti.py index 2e92ae4..ee65b44 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -607,6 +607,34 @@ def main(): | |||
| 607 | eps=args.adam_epsilon, | 607 | eps=args.adam_epsilon, |
| 608 | amsgrad=args.adam_amsgrad, | 608 | amsgrad=args.adam_amsgrad, |
| 609 | ) | 609 | ) |
| 610 | elif args.optimizer == 'dadam': | ||
| 611 | try: | ||
| 612 | import dadaptation | ||
| 613 | except ImportError: | ||
| 614 | raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") | ||
| 615 | |||
| 616 | create_optimizer = partial( | ||
| 617 | dadaptation.DAdaptAdam, | ||
| 618 | betas=(args.adam_beta1, args.adam_beta2), | ||
| 619 | weight_decay=args.adam_weight_decay, | ||
| 620 | eps=args.adam_epsilon, | ||
| 621 | decouple=True, | ||
| 622 | ) | ||
| 623 | |||
| 624 | args.learning_rate = 1.0 | ||
| 625 | elif args.optimizer == 'dadan': | ||
| 626 | try: | ||
| 627 | import dadaptation | ||
| 628 | except ImportError: | ||
| 629 | raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") | ||
| 630 | |||
| 631 | create_optimizer = partial( | ||
| 632 | dadaptation.DAdaptAdan, | ||
| 633 | weight_decay=args.adam_weight_decay, | ||
| 634 | eps=args.adam_epsilon, | ||
| 635 | ) | ||
| 636 | |||
| 637 | args.learning_rate = 1.0 | ||
| 610 | else: | 638 | else: |
| 611 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 639 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") |
| 612 | 640 | ||
