 environment.yaml    |  2 +-
 train_dreambooth.py | 28 ++++++++++++++++++++++++++++
 train_lora.py       | 28 ++++++++++++++++++++++++++++
 train_ti.py         | 28 ++++++++++++++++++++++++++++
 4 files changed, 85 insertions(+), 1 deletion(-)
diff --git a/environment.yaml b/environment.yaml
index db43bd5..42b568f 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -11,10 +11,10 @@ dependencies:
   - python=3.10.8
   - pytorch=2.0.0=*cuda11.8*
   - torchvision=0.15.0
-  # - xformers=0.0.17.dev476
   - pip:
     - -e .
     - -e git+https://github.com/huggingface/diffusers#egg=diffusers
+    - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation
     - accelerate==0.17.1
     - bitsandbytes==0.37.1
     - peft==0.2.0
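
The only dependency change is pulling dadaptation straight from the facebookresearch GitHub repository as an editable pip install, mirroring how diffusers is already handled, while the commented-out xformers pin is dropped. As a quick sanity check after recreating the environment, the two optimizer classes the training scripts below rely on should be importable; a minimal sketch (the script name is hypothetical, not part of this commit):

# check_dadaptation.py -- hypothetical sanity check, not part of this commit.
# Confirms the optimizers referenced by the train_*.py changes below are importable.
import dadaptation

print(dadaptation.DAdaptAdam)  # Adam-style optimizer with automatic step-size estimation
print(dadaptation.DAdaptAdan)  # Adan-style optimizer with automatic step-size estimation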
diff --git a/train_dreambooth.py b/train_dreambooth.py
index dd2bf6e..b706d07 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -512,6 +512,34 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'dadam':
+        try:
+            import dadaptation
+        except ImportError:
+            raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.")
+
+        create_optimizer = partial(
+            dadaptation.DAdaptAdam,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            decouple=True,
+        )
+
+        args.learning_rate = 1.0
+    elif args.optimizer == 'dadan':
+        try:
+            import dadaptation
+        except ImportError:
+            raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.")
+
+        create_optimizer = partial(
+            dadaptation.DAdaptAdan,
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+        )
+
+        args.learning_rate = 1.0
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")

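
Both new branches in train_dreambooth.py follow the pattern of the existing Adam paths: they bind everything except the parameters and learning rate into a create_optimizer partial, then force args.learning_rate to 1.0, since D-Adaptation estimates the effective step size itself and treats the nominal lr as a plain multiplier. A minimal, self-contained sketch of how the 'dadam' branch is presumably consumed downstream (the toy model and the concrete hyperparameter values are stand-ins for the script's argparse values, not taken from this commit):

# Sketch only: mirrors the 'dadam' branch above on a toy model.
from functools import partial

import torch
import dadaptation

create_optimizer = partial(
    dadaptation.DAdaptAdam,
    betas=(0.9, 0.999),    # stand-in for (args.adam_beta1, args.adam_beta2)
    weight_decay=1e-2,     # stand-in for args.adam_weight_decay
    eps=1e-8,              # stand-in for args.adam_epsilon
    decouple=True,         # decoupled (AdamW-style) weight decay
)

model = torch.nn.Linear(8, 8)
# lr stays at 1.0: D-Adaptation scales the step size internally, which is
# why the branch overrides args.learning_rate above.
optimizer = create_optimizer(model.parameters(), lr=1.0)

loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()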
diff --git a/train_lora.py b/train_lora.py
index 2a798f3..ce8fb50 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -476,6 +476,34 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'dadam':
+        try:
+            import dadaptation
+        except ImportError:
+            raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.")
+
+        create_optimizer = partial(
+            dadaptation.DAdaptAdam,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            decouple=True,
+        )
+
+        args.learning_rate = 1.0
+    elif args.optimizer == 'dadan':
+        try:
+            import dadaptation
+        except ImportError:
+            raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.")
+
+        create_optimizer = partial(
+            dadaptation.DAdaptAdan,
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+        )
+
+        args.learning_rate = 1.0
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")

diff --git a/train_ti.py b/train_ti.py
index 2e92ae4..ee65b44 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -607,6 +607,34 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'dadam':
+        try:
+            import dadaptation
+        except ImportError:
+            raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.")
+
+        create_optimizer = partial(
+            dadaptation.DAdaptAdam,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            decouple=True,
+        )
+
+        args.learning_rate = 1.0
+    elif args.optimizer == 'dadan':
+        try:
+            import dadaptation
+        except ImportError:
+            raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.")
+
+        create_optimizer = partial(
+            dadaptation.DAdaptAdan,
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+        )
+
+        args.learning_rate = 1.0
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")

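
The 'dadan' branch is identical across all three scripts, so a single sketch covers it: DAdaptAdan is given only the weight decay and epsilon (no betas or amsgrad), and likewise runs with the learning rate pinned to 1.0. Again a hedged illustration on a toy model, not the scripts' actual wiring:

# Sketch only: mirrors the 'dadan' branch shared by all three training scripts.
from functools import partial

import torch
import dadaptation

create_optimizer = partial(
    dadaptation.DAdaptAdan,
    weight_decay=1e-2,   # stand-in for args.adam_weight_decay
    eps=1e-8,            # stand-in for args.adam_epsilon
)

model = torch.nn.Linear(8, 8)
optimizer = create_optimizer(model.parameters(), lr=1.0)  # lr pinned to 1.0, as in the diff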