summary | refs | log | tree | commit | diff | stats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r-- train_ti.py 13
1 file changed, 12 insertions, 1 deletion
diff --git a/train_ti.py b/train_ti.py
index 6757bde..fc0d68c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -330,7 +330,7 @@ def parse_args():
330 "--optimizer", 330 "--optimizer",
331 type=str, 331 type=str,
332 default="dadan", 332 default="dadan",
333 choices=["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"], 333 choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
334 help='Optimizer to use' 334 help='Optimizer to use'
335 ) 335 )
336 parser.add_argument( 336 parser.add_argument(
@@ -679,6 +679,17 @@ def main():
679 eps=args.adam_epsilon, 679 eps=args.adam_epsilon,
680 amsgrad=args.adam_amsgrad, 680 amsgrad=args.adam_amsgrad,
681 ) 681 )
682 elif args.optimizer == 'adan':
683 try:
684 import timm.optim
685 except ImportError:
686 raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.")
687
688 create_optimizer = partial(
689 timm.optim.Adan,
690 weight_decay=args.adam_weight_decay,
691 eps=args.adam_epsilon,
692 )
682 elif args.optimizer == 'lion': 693 elif args.optimizer == 'lion':
683 try: 694 try:
684 import lion_pytorch 695 import lion_pytorch