From a08cd4b1581ca195f619e8bdb6cb6448287d4d2f Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 3 Apr 2023 12:39:17 +0200
Subject: Bring back Lion optimizer

---
 environment.yaml    |  1 +
 train_dreambooth.py | 30 +++++++++++++++++++++++++++---
 train_lora.py       | 30 +++++++++++++++++++++++++++---
 train_ti.py         | 30 +++++++++++++++++++++++++++---
 4 files changed, 82 insertions(+), 9 deletions(-)

diff --git a/environment.yaml b/environment.yaml
index 8868532..1de76bd 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -18,6 +18,7 @@ dependencies:
     - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation
     - accelerate==0.17.1
     - bitsandbytes==0.37.2
+    - lion-pytorch==0.0.7
     - peft==0.2.0
     - python-slugify>=6.1.2
     - safetensors==0.3.0
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 48b7926..be7d6fe 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -306,7 +306,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
+        help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -317,13 +317,13 @@ def parse_args():
     parser.add_argument(
         "--adam_beta1",
         type=float,
-        default=0.9,
+        default=None,
         help="The beta1 parameter for the Adam optimizer."
     )
     parser.add_argument(
         "--adam_beta2",
         type=float,
-        default=0.999,
+        default=None,
         help="The beta2 parameter for the Adam optimizer."
     )
     parser.add_argument(
@@ -450,6 +450,18 @@ def parse_args():
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
+    if args.adam_beta1 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta1 = 0.9
+        elif args.optimizer == 'lion':
+            args.adam_beta1 = 0.95
+
+    if args.adam_beta2 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta2 = 0.999
+        elif args.optimizer == 'lion':
+            args.adam_beta2 = 0.98
+
     return args
 
 
@@ -536,6 +548,18 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'lion':
+        try:
+            import lion_pytorch
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.")
+
+        create_optimizer = partial(
+            lion_pytorch.Lion,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            use_triton=True,
+        )
     elif args.optimizer == 'adafactor':
         create_optimizer = partial(
             transformers.optimization.Adafactor,
diff --git a/train_lora.py b/train_lora.py
index cf73645..a0cd174 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -318,7 +318,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
+        help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -329,13 +329,13 @@ def parse_args():
     parser.add_argument(
         "--adam_beta1",
         type=float,
-        default=0.9,
+        default=None,
         help="The beta1 parameter for the Adam optimizer."
     )
     parser.add_argument(
         "--adam_beta2",
         type=float,
-        default=0.999,
+        default=None,
         help="The beta2 parameter for the Adam optimizer."
     )
     parser.add_argument(
@@ -468,6 +468,18 @@ def parse_args():
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
+    if args.adam_beta1 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta1 = 0.9
+        elif args.optimizer == 'lion':
+            args.adam_beta1 = 0.95
+
+    if args.adam_beta2 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta2 = 0.999
+        elif args.optimizer == 'lion':
+            args.adam_beta2 = 0.98
+
     return args
 
 
@@ -568,6 +580,18 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'lion':
+        try:
+            import lion_pytorch
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.")
+
+        create_optimizer = partial(
+            lion_pytorch.Lion,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            use_triton=True,
+        )
     elif args.optimizer == 'adafactor':
         create_optimizer = partial(
             transformers.optimization.Adafactor,
diff --git a/train_ti.py b/train_ti.py
index 651dfbe..c242625 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -330,7 +330,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
+        help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -341,13 +341,13 @@ def parse_args():
     parser.add_argument(
         "--adam_beta1",
         type=float,
-        default=0.9,
+        default=None,
         help="The beta1 parameter for the Adam optimizer."
     )
     parser.add_argument(
         "--adam_beta2",
         type=float,
-        default=0.999,
+        default=None,
         help="The beta2 parameter for the Adam optimizer."
     )
     parser.add_argument(
@@ -566,6 +566,18 @@ def parse_args():
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
+    if args.adam_beta1 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta1 = 0.9
+        elif args.optimizer == 'lion':
+            args.adam_beta1 = 0.95
+
+    if args.adam_beta2 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta2 = 0.999
+        elif args.optimizer == 'lion':
+            args.adam_beta2 = 0.98
+
     return args
 
 
@@ -666,6 +678,18 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'lion':
+        try:
+            import lion_pytorch
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.")
+
+        create_optimizer = partial(
+            lion_pytorch.Lion,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            use_triton=True,
+        )
    elif args.optimizer == 'adafactor':
         create_optimizer = partial(
             transformers.optimization.Adafactor,
-- 
cgit v1.2.3-70-g09d2
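
For context, a minimal sketch of what the new `lion` branch wires up: a `functools.partial` factory around `lion_pytorch.Lion` with the patch's Lion defaults of `betas=(0.95, 0.98)`, later called with the parameters and learning rate. The toy model, learning rate, and weight-decay value below are illustrative stand-ins for the scripts' real arguments, and `use_triton` is left off here so the snippet runs without Triton (the patch enables it).

```python
# Sketch of the Lion optimizer factory introduced by this patch.
# Assumes `pip install lion-pytorch`; model, lr, and weight decay are illustrative only.
from functools import partial

import torch
from lion_pytorch import Lion

# Mirrors the patch: betas default to (0.95, 0.98) for Lion, and the existing
# --adam_weight_decay flag is reused for Lion's weight decay.
create_optimizer = partial(
    Lion,
    betas=(0.95, 0.98),
    weight_decay=2e-2,   # hypothetical stand-in for args.adam_weight_decay
    use_triton=False,    # the patch passes True, which requires the triton package
)

model = torch.nn.Linear(4, 4)  # stand-in for the trainable parameters
optimizer = create_optimizer(model.parameters(), lr=1e-4)

loss = model(torch.randn(2, 4)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

Deferring construction through `partial` keeps the optimizer choice separate from the parameter groups and learning rate, which the surrounding scripts bind later; note that Lion is generally run with a smaller learning rate than AdamW, so the Adam defaults should not be carried over unchanged.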