From f894dfecfaa3ec17903b2ac37ac4f071408613db Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 17 Feb 2023 21:06:11 +0100
Subject: Added Lion optimizer

---
 train_lora.py | 38 +++++++++++++++++++++++++++-----------
 1 file changed, 27 insertions(+), 11 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index 330bcd6..368c29b 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -245,9 +245,10 @@ def parse_args():
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
-        "--use_8bit_adam",
-        action="store_true",
-        help="Whether or not to use 8-bit Adam from bitsandbytes."
+        "--optimizer",
+        type=str,
+        default="lion",
+        help='Optimizer to use ["adam", "adam8bit", "lion"]'
     )
     parser.add_argument(
         "--adam_beta1",
@@ -466,15 +467,34 @@ def main():
         args.learning_rate = 1e-6
         args.lr_scheduler = "exponential_growth"
 
-    if args.use_8bit_adam:
+    if args.optimizer == 'adam8bit':
         try:
             import bitsandbytes as bnb
         except ImportError:
             raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
 
-        optimizer_class = bnb.optim.AdamW8bit
+        create_optimizer = partial(
+            bnb.optim.AdamW8bit,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
+    elif args.optimizer == 'adam':
+        create_optimizer = partial(
+            torch.optim.AdamW,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
     else:
-        optimizer_class = torch.optim.AdamW
+        try:
+            from lion_pytorch import Lion
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
+
+        create_optimizer = partial(Lion, use_triton=True)
 
     trainer = partial(
         train,
@@ -516,13 +536,9 @@ def main():
     )
     datamodule.setup()
 
-    optimizer = optimizer_class(
+    optimizer = create_optimizer(
         lora_layers.parameters(),
         lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
     )
 
     lr_scheduler = get_scheduler(
-- 
cgit v1.2.3-54-g00ecf
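
The net effect of the patch is to replace the single optimizer_class with a create_optimizer factory built via functools.partial: each branch binds its optimizer-specific keyword arguments up front, so the call site only supplies the parameters and learning rate. Below is a minimal, self-contained sketch of that pattern. The branch logic mirrors the diff, but the helper name make_create_optimizer, the stand-in model, and the hyperparameter values are hypothetical, not taken from train_lora.py.

    from functools import partial

    import torch


    def make_create_optimizer(name, beta1, beta2, weight_decay, epsilon, amsgrad):
        # Mirrors the diff: bind optimizer-specific kwargs now, leave
        # params and lr to the call site.
        if name == "adam8bit":
            import bitsandbytes as bnb  # optional dependency, as in the patch

            return partial(
                bnb.optim.AdamW8bit,
                betas=(beta1, beta2),
                weight_decay=weight_decay,
                eps=epsilon,
                amsgrad=amsgrad,
            )
        if name == "adam":
            return partial(
                torch.optim.AdamW,
                betas=(beta1, beta2),
                weight_decay=weight_decay,
                eps=epsilon,
                amsgrad=amsgrad,
            )
        # Lion branch: as in the diff, only use_triton is bound, so betas
        # and weight decay fall back to lion_pytorch's defaults.
        from lion_pytorch import Lion

        return partial(Lion, use_triton=True)


    # Hypothetical usage; a Linear layer stands in for lora_layers.
    create_optimizer = make_create_optimizer("adam", 0.9, 0.999, 1e-2, 1e-8, False)
    model = torch.nn.Linear(4, 4)
    optimizer = create_optimizer(model.parameters(), lr=1e-4)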
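
Two behavioral notes on the diff itself: the default optimizer is now "lion", so runs that previously fell back to torch.optim.AdamW switch optimizers unless --optimizer adam is passed explicitly, and the Lion branch does not forward any of the --adam_* hyperparameters, so Lion runs on lion_pytorch's default betas and weight decay. Binding use_triton=True also assumes the triton package is installed alongside lion_pytorch. A plausible invocation (the --learning_rate flag is inferred from args.learning_rate in the diff, not confirmed) would be:

    pip install lion_pytorch triton
    python train_lora.py --optimizer lion --learning_rate 1e-4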