diff options
| author | Volpeon <git@volpeon.ink> | 2023-04-09 11:42:56 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-09 11:42:56 +0200 |
| commit | 57543ee71d6ddce68bc7cec45fae45ce7d998f61 (patch) | |
| tree | 2d6e3b7cdee0a38d71073d1a6a5c42e61a606308 | |
| parent | Update (diff) | |
| download | textual-inversion-diff-57543ee71d6ddce68bc7cec45fae45ce7d998f61.tar.gz textual-inversion-diff-57543ee71d6ddce68bc7cec45fae45ce7d998f61.tar.bz2 textual-inversion-diff-57543ee71d6ddce68bc7cec45fae45ce7d998f61.zip | |
Update
| -rw-r--r-- | train_lora.py | 7 |
1 files changed, 1 insertions, 6 deletions
diff --git a/train_lora.py b/train_lora.py index 6e21634..54c9e7a 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -307,11 +307,6 @@ def parse_args(): | |||
| 307 | help="Initial learning rate (after the potential warmup period) to use.", | 307 | help="Initial learning rate (after the potential warmup period) to use.", |
| 308 | ) | 308 | ) |
| 309 | parser.add_argument( | 309 | parser.add_argument( |
| 310 | "--train_emb", | ||
| 311 | action="store_true", | ||
| 312 | help="Keep training text embeddings.", | ||
| 313 | ) | ||
| 314 | parser.add_argument( | ||
| 315 | "--scale_lr", | 310 | "--scale_lr", |
| 316 | action="store_true", | 311 | action="store_true", |
| 317 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", | 312 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", |
| @@ -899,7 +894,7 @@ def main(): | |||
| 899 | 894 | ||
| 900 | params_to_optimize = [] | 895 | params_to_optimize = [] |
| 901 | group_labels = [] | 896 | group_labels = [] |
| 902 | if len(args.placeholder_tokens) != 0 and args.train_emb: | 897 | if len(args.placeholder_tokens) != 0: |
| 903 | params_to_optimize.append({ | 898 | params_to_optimize.append({ |
| 904 | "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(), | 899 | "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(), |
| 905 | "lr": args.learning_rate_emb, | 900 | "lr": args.learning_rate_emb, |
