From 744b87831f5e854d86c9f39c131386c3b26e9304 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 21 Mar 2023 14:18:08 +0100 Subject: Added dadaptation --- environment.yaml | 2 +- train_dreambooth.py | 28 ++++++++++++++++++++++++++++ train_lora.py | 28 ++++++++++++++++++++++++++++ train_ti.py | 28 ++++++++++++++++++++++++++++ 4 files changed, 85 insertions(+), 1 deletion(-) diff --git a/environment.yaml b/environment.yaml index db43bd5..42b568f 100644 --- a/environment.yaml +++ b/environment.yaml @@ -11,10 +11,10 @@ dependencies: - python=3.10.8 - pytorch=2.0.0=*cuda11.8* - torchvision=0.15.0 - # - xformers=0.0.17.dev476 - pip: - -e . - -e git+https://github.com/huggingface/diffusers#egg=diffusers + - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation - accelerate==0.17.1 - bitsandbytes==0.37.1 - peft==0.2.0 diff --git a/train_dreambooth.py b/train_dreambooth.py index dd2bf6e..b706d07 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -512,6 +512,34 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) + elif args.optimizer == 'dadam': + try: + import dadaptation + except ImportError: + raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") + + create_optimizer = partial( + dadaptation.DAdaptAdam, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=True, + ) + + args.learning_rate = 1.0 + elif args.optimizer == 'dadan': + try: + import dadaptation + except ImportError: + raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") + + create_optimizer = partial( + dadaptation.DAdaptAdan, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + args.learning_rate = 1.0 else: raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") diff --git a/train_lora.py b/train_lora.py index 2a798f3..ce8fb50 100644 --- a/train_lora.py +++ b/train_lora.py @@ -476,6 +476,34 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) + elif args.optimizer == 'dadam': + try: + import dadaptation + except ImportError: + raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") + + create_optimizer = partial( + dadaptation.DAdaptAdam, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=True, + ) + + args.learning_rate = 1.0 + elif args.optimizer == 'dadan': + try: + import dadaptation + except ImportError: + raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") + + create_optimizer = partial( + dadaptation.DAdaptAdan, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + args.learning_rate = 1.0 else: raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") diff --git a/train_ti.py b/train_ti.py index 2e92ae4..ee65b44 100644 --- a/train_ti.py +++ b/train_ti.py @@ -607,6 +607,34 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) + elif args.optimizer == 'dadam': + try: + import dadaptation + except ImportError: + raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") + + create_optimizer = partial( + dadaptation.DAdaptAdam, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=True, + ) + + args.learning_rate = 1.0 + elif args.optimizer == 'dadan': + try: + import dadaptation + except ImportError: + raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") + + create_optimizer = partial( + dadaptation.DAdaptAdan, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + args.learning_rate = 1.0 else: raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") -- cgit v1.2.3-70-g09d2