summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--environment.yaml2
-rw-r--r--train_dreambooth.py28
-rw-r--r--train_lora.py28
-rw-r--r--train_ti.py28
4 files changed, 85 insertions, 1 deletions
diff --git a/environment.yaml b/environment.yaml
index db43bd5..42b568f 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -11,10 +11,10 @@ dependencies:
11 - python=3.10.8 11 - python=3.10.8
12 - pytorch=2.0.0=*cuda11.8* 12 - pytorch=2.0.0=*cuda11.8*
13 - torchvision=0.15.0 13 - torchvision=0.15.0
14 # - xformers=0.0.17.dev476
15 - pip: 14 - pip:
16 - -e . 15 - -e .
17 - -e git+https://github.com/huggingface/diffusers#egg=diffusers 16 - -e git+https://github.com/huggingface/diffusers#egg=diffusers
17 - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation
18 - accelerate==0.17.1 18 - accelerate==0.17.1
19 - bitsandbytes==0.37.1 19 - bitsandbytes==0.37.1
20 - peft==0.2.0 20 - peft==0.2.0
diff --git a/train_dreambooth.py b/train_dreambooth.py
index dd2bf6e..b706d07 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -512,6 +512,34 @@ def main():
512 eps=args.adam_epsilon, 512 eps=args.adam_epsilon,
513 amsgrad=args.adam_amsgrad, 513 amsgrad=args.adam_amsgrad,
514 ) 514 )
515 elif args.optimizer == 'dadam':
516 try:
517 import dadaptation
518 except ImportError:
519 raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.")
520
521 create_optimizer = partial(
522 dadaptation.DAdaptAdam,
523 betas=(args.adam_beta1, args.adam_beta2),
524 weight_decay=args.adam_weight_decay,
525 eps=args.adam_epsilon,
526 decouple=True,
527 )
528
529 args.learning_rate = 1.0
530 elif args.optimizer == 'dadan':
531 try:
532 import dadaptation
533 except ImportError:
534 raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.")
535
536 create_optimizer = partial(
537 dadaptation.DAdaptAdan,
538 weight_decay=args.adam_weight_decay,
539 eps=args.adam_epsilon,
540 )
541
542 args.learning_rate = 1.0
515 else: 543 else:
516 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") 544 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
517 545
diff --git a/train_lora.py b/train_lora.py
index 2a798f3..ce8fb50 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -476,6 +476,34 @@ def main():
476 eps=args.adam_epsilon, 476 eps=args.adam_epsilon,
477 amsgrad=args.adam_amsgrad, 477 amsgrad=args.adam_amsgrad,
478 ) 478 )
479 elif args.optimizer == 'dadam':
480 try:
481 import dadaptation
482 except ImportError:
483 raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.")
484
485 create_optimizer = partial(
486 dadaptation.DAdaptAdam,
487 betas=(args.adam_beta1, args.adam_beta2),
488 weight_decay=args.adam_weight_decay,
489 eps=args.adam_epsilon,
490 decouple=True,
491 )
492
493 args.learning_rate = 1.0
494 elif args.optimizer == 'dadan':
495 try:
496 import dadaptation
497 except ImportError:
498 raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.")
499
500 create_optimizer = partial(
501 dadaptation.DAdaptAdan,
502 weight_decay=args.adam_weight_decay,
503 eps=args.adam_epsilon,
504 )
505
506 args.learning_rate = 1.0
479 else: 507 else:
480 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") 508 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
481 509
diff --git a/train_ti.py b/train_ti.py
index 2e92ae4..ee65b44 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -607,6 +607,34 @@ def main():
607 eps=args.adam_epsilon, 607 eps=args.adam_epsilon,
608 amsgrad=args.adam_amsgrad, 608 amsgrad=args.adam_amsgrad,
609 ) 609 )
610 elif args.optimizer == 'dadam':
611 try:
612 import dadaptation
613 except ImportError:
614 raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.")
615
616 create_optimizer = partial(
617 dadaptation.DAdaptAdam,
618 betas=(args.adam_beta1, args.adam_beta2),
619 weight_decay=args.adam_weight_decay,
620 eps=args.adam_epsilon,
621 decouple=True,
622 )
623
624 args.learning_rate = 1.0
625 elif args.optimizer == 'dadan':
626 try:
627 import dadaptation
628 except ImportError:
629 raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.")
630
631 create_optimizer = partial(
632 dadaptation.DAdaptAdan,
633 weight_decay=args.adam_weight_decay,
634 eps=args.adam_epsilon,
635 )
636
637 args.learning_rate = 1.0
610 else: 638 else:
611 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") 639 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
612 640