summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--train_lora.py7
1 files changed, 1 insertions, 6 deletions
diff --git a/train_lora.py b/train_lora.py
index 6e21634..54c9e7a 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -307,11 +307,6 @@ def parse_args():
307 help="Initial learning rate (after the potential warmup period) to use.", 307 help="Initial learning rate (after the potential warmup period) to use.",
308 ) 308 )
309 parser.add_argument( 309 parser.add_argument(
310 "--train_emb",
311 action="store_true",
312 help="Keep training text embeddings.",
313 )
314 parser.add_argument(
315 "--scale_lr", 310 "--scale_lr",
316 action="store_true", 311 action="store_true",
317 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 312 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
@@ -899,7 +894,7 @@ def main():
899 894
900 params_to_optimize = [] 895 params_to_optimize = []
901 group_labels = [] 896 group_labels = []
902 if len(args.placeholder_tokens) != 0 and args.train_emb: 897 if len(args.placeholder_tokens) != 0:
903 params_to_optimize.append({ 898 params_to_optimize.append({
904 "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(), 899 "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
905 "lr": args.learning_rate_emb, 900 "lr": args.learning_rate_emb,