diff options
author | Volpeon <git@volpeon.ink> | 2023-04-09 11:42:56 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-09 11:42:56 +0200 |
commit | 57543ee71d6ddce68bc7cec45fae45ce7d998f61 (patch) | |
tree | 2d6e3b7cdee0a38d71073d1a6a5c42e61a606308 | |
parent | Update (diff) | |
download | textual-inversion-diff-57543ee71d6ddce68bc7cec45fae45ce7d998f61.tar.gz textual-inversion-diff-57543ee71d6ddce68bc7cec45fae45ce7d998f61.tar.bz2 textual-inversion-diff-57543ee71d6ddce68bc7cec45fae45ce7d998f61.zip |
Update
-rw-r--r-- | train_lora.py | 7 |
1 files changed, 1 insertions, 6 deletions
diff --git a/train_lora.py b/train_lora.py index 6e21634..54c9e7a 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -307,11 +307,6 @@ def parse_args(): | |||
307 | help="Initial learning rate (after the potential warmup period) to use.", | 307 | help="Initial learning rate (after the potential warmup period) to use.", |
308 | ) | 308 | ) |
309 | parser.add_argument( | 309 | parser.add_argument( |
310 | "--train_emb", | ||
311 | action="store_true", | ||
312 | help="Keep training text embeddings.", | ||
313 | ) | ||
314 | parser.add_argument( | ||
315 | "--scale_lr", | 310 | "--scale_lr", |
316 | action="store_true", | 311 | action="store_true", |
317 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", | 312 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", |
@@ -899,7 +894,7 @@ def main(): | |||
899 | 894 | ||
900 | params_to_optimize = [] | 895 | params_to_optimize = [] |
901 | group_labels = [] | 896 | group_labels = [] |
902 | if len(args.placeholder_tokens) != 0 and args.train_emb: | 897 | if len(args.placeholder_tokens) != 0: |
903 | params_to_optimize.append({ | 898 | params_to_optimize.append({ |
904 | "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(), | 899 | "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(), |
905 | "lr": args.learning_rate_emb, | 900 | "lr": args.learning_rate_emb, |