diff options
author | Volpeon <git@volpeon.ink> | 2023-05-16 07:12:14 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-05-16 07:12:14 +0200 |
commit | b31fcb741432076f7e2f3ec9423ad935a08c6671 (patch) | |
tree | 2ab052d3bd617a56c4ea388c200da52cff39ba37 /training/strategy | |
parent | Fix for latest PEFT (diff) | |
download | textual-inversion-diff-b31fcb741432076f7e2f3ec9423ad935a08c6671.tar.gz textual-inversion-diff-b31fcb741432076f7e2f3ec9423ad935a08c6671.tar.bz2 textual-inversion-diff-b31fcb741432076f7e2f3ec9423ad935a08c6671.zip |
Support LoRA training for token embeddings
Diffstat (limited to 'training/strategy')
-rw-r--r-- | training/strategy/lora.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 0c0f633..f942b76 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -92,7 +92,7 @@ def lora_strategy_callbacks( | |||
92 | max_grad_norm | 92 | max_grad_norm |
93 | ) | 93 | ) |
94 | 94 | ||
95 | if use_emb_decay: | 95 | if len(placeholder_tokens) != 0 and use_emb_decay: |
96 | params = [ | 96 | params = [ |
97 | p | 97 | p |
98 | for p in text_encoder.text_model.embeddings.parameters() | 98 | for p in text_encoder.text_model.embeddings.parameters() |
@@ -102,7 +102,7 @@ def lora_strategy_callbacks( | |||
102 | 102 | ||
103 | @torch.no_grad() | 103 | @torch.no_grad() |
104 | def on_after_optimize(w, lrs: dict[str, float]): | 104 | def on_after_optimize(w, lrs: dict[str, float]): |
105 | if use_emb_decay and w is not None and "emb" in lrs: | 105 | if w is not None and "emb" in lrs: |
106 | lr = lrs["emb"] | 106 | lr = lrs["emb"] |
107 | lambda_ = emb_decay * lr | 107 | lambda_ = emb_decay * lr |
108 | 108 | ||