diff options
| -rw-r--r-- | train_lora.py | 7 |
1 files changed, 1 insertions, 6 deletions
diff --git a/train_lora.py b/train_lora.py index 6e21634..54c9e7a 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -307,11 +307,6 @@ def parse_args(): | |||
| 307 | help="Initial learning rate (after the potential warmup period) to use.", | 307 | help="Initial learning rate (after the potential warmup period) to use.", |
| 308 | ) | 308 | ) |
| 309 | parser.add_argument( | 309 | parser.add_argument( |
| 310 | "--train_emb", | ||
| 311 | action="store_true", | ||
| 312 | help="Keep training text embeddings.", | ||
| 313 | ) | ||
| 314 | parser.add_argument( | ||
| 315 | "--scale_lr", | 310 | "--scale_lr", |
| 316 | action="store_true", | 311 | action="store_true", |
| 317 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", | 312 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", |
| @@ -899,7 +894,7 @@ def main(): | |||
| 899 | 894 | ||
| 900 | params_to_optimize = [] | 895 | params_to_optimize = [] |
| 901 | group_labels = [] | 896 | group_labels = [] |
| 902 | if len(args.placeholder_tokens) != 0 and args.train_emb: | 897 | if len(args.placeholder_tokens) != 0: |
| 903 | params_to_optimize.append({ | 898 | params_to_optimize.append({ |
| 904 | "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(), | 899 | "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(), |
| 905 | "lr": args.learning_rate_emb, | 900 | "lr": args.learning_rate_emb, |
