From f4f996681ca340e940315ca0ebc162c655904a7d Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 5 Apr 2023 16:02:04 +0200
Subject: Add color jitter

---
 train_lora.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

(limited to 'train_lora.py')

diff --git a/train_lora.py b/train_lora.py
index 538a7f7..73b3e19 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -317,7 +317,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        choices=["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"],
+        choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
         help='Optimizer to use'
     )
     parser.add_argument(
@@ -544,8 +544,6 @@ def main():
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
-        embeddings.persist()
-
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
@@ -580,6 +578,17 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'adan':
+        try:
+            import timm.optim
+        except ImportError:
+            raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.")
+
+        create_optimizer = partial(
+            timm.optim.Adan,
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+        )
     elif args.optimizer == 'lion':
         try:
             import lion_pytorch
-- 
cgit v1.2.3-54-g00ecf