From a08cd4b1581ca195f619e8bdb6cb6448287d4d2f Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 3 Apr 2023 12:39:17 +0200
Subject: Bring back Lion optimizer

---
 train_ti.py | 30 +++++++++++++++++++++++++++---
 1 file changed, 27 insertions(+), 3 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index 651dfbe..c242625 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -330,7 +330,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
+        help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -341,13 +341,13 @@ def parse_args():
     parser.add_argument(
         "--adam_beta1",
         type=float,
-        default=0.9,
+        default=None,
         help="The beta1 parameter for the Adam optimizer."
     )
     parser.add_argument(
         "--adam_beta2",
         type=float,
-        default=0.999,
+        default=None,
         help="The beta2 parameter for the Adam optimizer."
     )
     parser.add_argument(
@@ -566,6 +566,18 @@ def parse_args():
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
+    if args.adam_beta1 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta1 = 0.9
+        elif args.optimizer == 'lion':
+            args.adam_beta1 = 0.95
+
+    if args.adam_beta2 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta2 = 0.999
+        elif args.optimizer == 'lion':
+            args.adam_beta2 = 0.98
+
     return args
 
 
@@ -666,6 +678,18 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'lion':
+        try:
+            import lion_pytorch
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.")
+
+        create_optimizer = partial(
+            lion_pytorch.Lion,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            use_triton=True,
+        )
     elif args.optimizer == 'adafactor':
         create_optimizer = partial(
             transformers.optimization.Adafactor,
-- 
cgit v1.2.3-54-g00ecf
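
For reference, the parse_args() hunks switch --adam_beta1/--adam_beta2 to default=None and fill them in per optimizer after parsing, so Adam keeps (0.9, 0.999) while Lion gets (0.95, 0.98) and the remaining optimizers leave the betas unset. A minimal sketch of the same fallback logic in isolation; resolve_betas is a hypothetical helper, not part of train_ti.py:

    # Sketch of the per-optimizer beta fallback added in parse_args().
    # resolve_betas is a hypothetical name; train_ti.py does this inline
    # on the argparse Namespace instead.
    def resolve_betas(optimizer, beta1=None, beta2=None):
        defaults = {
            'adam': (0.9, 0.999),
            'adam8bit': (0.9, 0.999),
            'lion': (0.95, 0.98),
        }
        d1, d2 = defaults.get(optimizer, (None, None))
        if beta1 is None:
            beta1 = d1
        if beta2 is None:
            beta2 = d2
        return beta1, beta2

    assert resolve_betas('lion') == (0.95, 0.98)
    assert resolve_betas('adam', beta1=0.5) == (0.5, 0.999)  # explicit flag wins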
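
For context on the new Lion branch: create_optimizer is a functools.partial that the training setup later calls with the trainable parameters and a learning rate. A self-contained sketch under stated assumptions — the nn.Linear stands in for the real embedding parameters, the lr and weight_decay values are placeholders, and use_triton=True additionally requires the triton package on a CUDA machine:

    # Sketch of how the Lion factory from the patch is consumed downstream.
    # Assumes lion-pytorch is installed; the model and hyperparameters are
    # placeholders, not values from train_ti.py.
    from functools import partial

    import torch.nn as nn
    import lion_pytorch  # pip install lion-pytorch

    create_optimizer = partial(
        lion_pytorch.Lion,
        betas=(0.95, 0.98),     # the Lion defaults resolved in parse_args()
        weight_decay=1e-2,      # placeholder for args.adam_weight_decay
        use_triton=True,        # fused Triton update kernel; needs triton + CUDA
    )

    model = nn.Linear(768, 768)  # placeholder for the trained embeddings
    optimizer = create_optimizer(model.parameters(), lr=1e-4)

Note that the patch reuses the existing adam_* flags for Lion's betas and weight decay rather than introducing lion-specific flags, which keeps the CLI surface unchanged.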