diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 13 |
1 file changed, 12 insertions, 1 deletion
diff --git a/train_ti.py b/train_ti.py index 6757bde..fc0d68c 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -330,7 +330,7 @@ def parse_args(): | |||
| 330 | "--optimizer", | 330 | "--optimizer", |
| 331 | type=str, | 331 | type=str, |
| 332 | default="dadan", | 332 | default="dadan", |
| 333 | choices=["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"], | 333 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], |
| 334 | help='Optimizer to use' | 334 | help='Optimizer to use' |
| 335 | ) | 335 | ) |
| 336 | parser.add_argument( | 336 | parser.add_argument( |
| @@ -679,6 +679,17 @@ def main(): | |||
| 679 | eps=args.adam_epsilon, | 679 | eps=args.adam_epsilon, |
| 680 | amsgrad=args.adam_amsgrad, | 680 | amsgrad=args.adam_amsgrad, |
| 681 | ) | 681 | ) |
| 682 | elif args.optimizer == 'adan': | ||
| 683 | try: | ||
| 684 | import timm.optim | ||
| 685 | except ImportError: | ||
| 686 | raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") | ||
| 687 | |||
| 688 | create_optimizer = partial( | ||
| 689 | timm.optim.Adan, | ||
| 690 | weight_decay=args.adam_weight_decay, | ||
| 691 | eps=args.adam_epsilon, | ||
| 692 | ) | ||
| 682 | elif args.optimizer == 'lion': | 693 | elif args.optimizer == 'lion': |
| 683 | try: | 694 | try: |
| 684 | import lion_pytorch | 695 | import lion_pytorch |
