Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index b8c7396..91bda5c 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -387,7 +387,7 @@ def parse_args():
     parser.add_argument(
         "--optimizer",
         type=str,
-        default="dadan",
+        default="adan",
         choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
         help='Optimizer to use'
     )
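A hedged sketch of how these choice strings plausibly resolve to optimizer constructors: only timm.optim.Adan is confirmed by this commit (see the hunk at line 780 below); the dadaptation and lion_pytorch imports are assumptions about where the "dadan" and "lion" choices come from.

    # Hypothetical dispatch for --optimizer; only timm.optim.Adan appears in this diff.
    from functools import partial

    import timm.optim

    def optimizer_factory(args):
        if args.optimizer == 'adan':
            return partial(timm.optim.Adan,
                           weight_decay=args.adam_weight_decay,
                           eps=args.adam_epsilon,
                           no_prox=True)
        if args.optimizer == 'dadan':
            from dadaptation import DAdaptAdan  # assumed origin of "dadan"
            return partial(DAdaptAdan, weight_decay=args.adam_weight_decay)
        if args.optimizer == 'lion':
            from lion_pytorch import Lion  # assumed origin of "lion"
            return partial(Lion, weight_decay=args.adam_weight_decay)
        raise ValueError(f"unsupported optimizer: {args.optimizer}")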
@@ -412,7 +412,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=1e-2,
+        default=2e-2,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -780,6 +780,7 @@ def main():
             timm.optim.Adan,
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
+            no_prox=True,
         )
     elif args.optimizer == 'lion':
         try:
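In timm's Adan, no_prox selects how weight decay is applied: the default (False) uses the proximal form from the Adan paper, while True applies decoupled, AdamW-style decay. A condensed sketch of the branch this flag controls, with `update` standing in for the combined moment terms of the real update step:

    import torch

    def adan_decay_step(param: torch.Tensor, update: torch.Tensor,
                        lr: float, weight_decay: float, no_prox: bool) -> None:
        # Condensed from timm.optim.Adan; only the weight-decay handling differs.
        if no_prox:
            # Decoupled (AdamW-style): shrink the weights, then apply the step.
            param.mul_(1 - lr * weight_decay)
            param.add_(update, alpha=-lr)
        else:
            # Proximal (paper default): apply the step, then rescale by the decay term.
            param.add_(update, alpha=-lr)
            param.div_(1 + lr * weight_decay)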
@@ -961,7 +962,7 @@ def main():
 
     if len(args.placeholder_tokens) != 0:
         params_to_optimize.append({
-            "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
+            "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
             "lr": learning_rate_emb,
             "weight_decay": 0,
         })
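The last hunk hands the token embeddings to the optimizer as their own parameter group. A minimal, self-contained sketch of the PyTorch param-group mechanism this relies on (the modules and learning rates here are placeholders, not values from train_lora.py):

    import torch
    import torch.nn as nn

    model = nn.Linear(16, 16)         # stand-in for the LoRA parameters
    embedding = nn.Embedding(10, 16)  # stand-in for the token embeddings

    # Keys in a group dict ("lr", "weight_decay") override the optimizer-wide
    # defaults for that group only, which is how the embeddings get their own
    # learning rate and zero weight decay.
    params_to_optimize = [
        {"params": model.parameters()},
        {"params": embedding.parameters(), "lr": 1e-4, "weight_decay": 0},
    ]
    optimizer = torch.optim.AdamW(params_to_optimize, lr=1e-3, weight_decay=2e-2)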