summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py
index 507d710..12e3644 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -292,7 +292,7 @@ def parse_args():
292 parser.add_argument( 292 parser.add_argument(
293 "--optimizer", 293 "--optimizer",
294 type=str, 294 type=str,
295 default="lion", 295 default="adam",
296 help='Optimizer to use ["adam", "adam8bit", "lion"]' 296 help='Optimizer to use ["adam", "adam8bit", "lion"]'
297 ) 297 )
298 parser.add_argument( 298 parser.add_argument(
@@ -586,13 +586,15 @@ def main():
586 eps=args.adam_epsilon, 586 eps=args.adam_epsilon,
587 amsgrad=args.adam_amsgrad, 587 amsgrad=args.adam_amsgrad,
588 ) 588 )
589 else: 589 elif args.optimizer == 'lion':
590 try: 590 try:
591 from lion_pytorch import Lion 591 from lion_pytorch import Lion
592 except ImportError: 592 except ImportError:
593 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") 593 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
594 594
595 create_optimizer = partial(Lion, use_triton=True) 595 create_optimizer = partial(Lion, use_triton=True)
596 else:
597 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
596 598
597 checkpoint_output_dir = output_dir/"checkpoints" 599 checkpoint_output_dir = output_dir/"checkpoints"
598 600