author    Volpeon <git@volpeon.ink>  2023-05-16 07:12:14 +0200
committer Volpeon <git@volpeon.ink>  2023-05-16 07:12:14 +0200
commit    b31fcb741432076f7e2f3ec9423ad935a08c6671 (patch)
tree      2ab052d3bd617a56c4ea388c200da52cff39ba37 /training/strategy
parent    Fix for latest PEFT (diff)
Support LoRA training for token embeddings
Diffstat (limited to 'training/strategy')
-rw-r--r--  training/strategy/lora.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 0c0f633..f942b76 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -92,7 +92,7 @@ def lora_strategy_callbacks(
                 max_grad_norm
             )
 
-        if use_emb_decay:
+        if len(placeholder_tokens) != 0 and use_emb_decay:
             params = [
                 p
                 for p in text_encoder.text_model.embeddings.parameters()
@@ -102,7 +102,7 @@ def lora_strategy_callbacks(
 
     @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
-        if use_emb_decay and w is not None and "emb" in lrs:
+        if w is not None and "emb" in lrs:
             lr = lrs["emb"]
             lambda_ = emb_decay * lr
 
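
For context, the changed condition in on_after_optimize feeds the "emb" learning rate into an embedding-norm decay step. Below is a minimal illustrative sketch of such a step, not the repository's code: the function name apply_emb_decay, the parameter emb_decay_target, and the pull-toward-target formula are assumptions made for illustration.

from typing import Optional

import torch


@torch.no_grad()
def apply_emb_decay(
    w: Optional[torch.Tensor],
    lrs: dict[str, float],
    emb_decay: float,
    emb_decay_target: float = 1.0,
):
    # Mirror the commit's guard: do nothing when no embedding gradients were
    # collected (w is None) or when no "emb" learning rate is scheduled.
    if w is None or "emb" not in lrs:
        return

    lr = lrs["emb"]
    lambda_ = emb_decay * lr

    if lambda_ != 0:
        # Nudge each embedding vector's norm toward the decay target,
        # scaled by the current learning rate.
        norm = w.norm(dim=-1, keepdim=True)
        w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))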