summaryrefslogtreecommitdiffstats
path: root/train_lora.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_lora.py')
-rw-r--r--train_lora.py7
1 file changed, 4 insertions, 3 deletions
diff --git a/train_lora.py b/train_lora.py
index b8c7396..91bda5c 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -387,7 +387,7 @@ def parse_args():
387 parser.add_argument( 387 parser.add_argument(
388 "--optimizer", 388 "--optimizer",
389 type=str, 389 type=str,
390 default="dadan", 390 default="adan",
391 choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], 391 choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
392 help='Optimizer to use' 392 help='Optimizer to use'
393 ) 393 )
@@ -412,7 +412,7 @@ def parse_args():
412 parser.add_argument( 412 parser.add_argument(
413 "--adam_weight_decay", 413 "--adam_weight_decay",
414 type=float, 414 type=float,
415 default=1e-2, 415 default=2e-2,
416 help="Weight decay to use." 416 help="Weight decay to use."
417 ) 417 )
418 parser.add_argument( 418 parser.add_argument(
@@ -780,6 +780,7 @@ def main():
780 timm.optim.Adan, 780 timm.optim.Adan,
781 weight_decay=args.adam_weight_decay, 781 weight_decay=args.adam_weight_decay,
782 eps=args.adam_epsilon, 782 eps=args.adam_epsilon,
783 no_prox=True,
783 ) 784 )
784 elif args.optimizer == 'lion': 785 elif args.optimizer == 'lion':
785 try: 786 try:
@@ -961,7 +962,7 @@ def main():
961 962
962 if len(args.placeholder_tokens) != 0: 963 if len(args.placeholder_tokens) != 0:
963 params_to_optimize.append({ 964 params_to_optimize.append({
964 "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(), 965 "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
965 "lr": learning_rate_emb, 966 "lr": learning_rate_emb,
966 "weight_decay": 0, 967 "weight_decay": 0,
967 }) 968 })