diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 15 |
1 file changed, 12 insertions, 3 deletions
diff --git a/train_lora.py b/train_lora.py index 538a7f7..73b3e19 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -317,7 +317,7 @@ def parse_args(): | |||
317 | "--optimizer", | 317 | "--optimizer", |
318 | type=str, | 318 | type=str, |
319 | default="dadan", | 319 | default="dadan", |
320 | choices=["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"], | 320 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], |
321 | help='Optimizer to use' | 321 | help='Optimizer to use' |
322 | ) | 322 | ) |
323 | parser.add_argument( | 323 | parser.add_argument( |
@@ -544,8 +544,6 @@ def main(): | |||
544 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 544 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
545 | raise ValueError("--embeddings_dir must point to an existing directory") | 545 | raise ValueError("--embeddings_dir must point to an existing directory") |
546 | 546 | ||
547 | embeddings.persist() | ||
548 | |||
549 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 547 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
550 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 548 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
551 | 549 | ||
@@ -580,6 +578,17 @@ def main(): | |||
580 | eps=args.adam_epsilon, | 578 | eps=args.adam_epsilon, |
581 | amsgrad=args.adam_amsgrad, | 579 | amsgrad=args.adam_amsgrad, |
582 | ) | 580 | ) |
581 | elif args.optimizer == 'adan': | ||
582 | try: | ||
583 | import timm.optim | ||
584 | except ImportError: | ||
585 | raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") | ||
586 | |||
587 | create_optimizer = partial( | ||
588 | timm.optim.Adan, | ||
589 | weight_decay=args.adam_weight_decay, | ||
590 | eps=args.adam_epsilon, | ||
591 | ) | ||
583 | elif args.optimizer == 'lion': | 592 | elif args.optimizer == 'lion': |
584 | try: | 593 | try: |
585 | import lion_pytorch | 594 | import lion_pytorch |