From 2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 18 Feb 2023 13:00:13 +0100 Subject: Update --- train_ti.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 507d710..12e3644 100644 --- a/train_ti.py +++ b/train_ti.py @@ -292,7 +292,7 @@ def parse_args(): parser.add_argument( "--optimizer", type=str, - default="lion", + default="adam", help='Optimizer to use ["adam", "adam8bit", "lion"]' ) parser.add_argument( @@ -586,13 +586,15 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) - else: + elif args.optimizer == 'lion': try: from lion_pytorch import Lion except ImportError: raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") create_optimizer = partial(Lion, use_triton=True) + else: + raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") checkpoint_output_dir = output_dir/"checkpoints" -- cgit v1.2.3-54-g00ecf