author     Volpeon <git@volpeon.ink>  2023-04-07 11:02:47 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-07 11:02:47 +0200
commit     f5b86b44565aaaa92543989a85ea5d88ca9b1c0c (patch)
tree       df02bdcf757743708001fe70e9db2c3e2b9b4af9 /training/strategy/lora.py
parent     Update (diff)
Fix
Diffstat (limited to 'training/strategy/lora.py')
 training/strategy/lora.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index d51a2f3..6730dc9 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -85,15 +85,16 @@ def lora_strategy_callbacks(
         )
 
         if use_emb_decay:
-            return torch.stack([
+            params = [
                 p
                 for p in text_encoder.text_model.embeddings.token_override_embedding.params
                 if p.grad is not None
-            ])
+            ]
+            return torch.stack(params) if len(params) != 0 else None
 
     @torch.no_grad()
     def on_after_optimize(w, lr: float):
-        if use_emb_decay:
+        if use_emb_decay and w is not None:
             lambda_ = emb_decay * lr
 
             if lambda_ != 0:
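
What the fix addresses: torch.stack() raises a RuntimeError when given an empty list, so the old code crashed whenever no override embedding had received a gradient in the current step. After the patch, on_before_optimize returns None in that case, and on_after_optimize skips the decay when w is None. Below is a minimal, self-contained sketch of the patched control flow; the module-level params list is a hypothetical stand-in for text_encoder.text_model.embeddings.token_override_embedding.params, and the in-place decay of the stacked copy is a simplified illustration, not the repository's actual update rule.

import torch

use_emb_decay = True
emb_decay = 1e-2

# Hypothetical stand-in for
# text_encoder.text_model.embeddings.token_override_embedding.params.
params = [torch.randn(4, requires_grad=True) for _ in range(3)]

@torch.no_grad()
def on_before_optimize():
    if use_emb_decay:
        with_grad = [p for p in params if p.grad is not None]
        # The fix: return None instead of calling torch.stack([]) when
        # no embedding received a gradient this step.
        return torch.stack(with_grad) if len(with_grad) != 0 else None

@torch.no_grad()
def on_after_optimize(w, lr: float):
    # The fix: tolerate a None result and skip the decay entirely.
    if use_emb_decay and w is not None:
        lambda_ = emb_decay * lr
        if lambda_ != 0:
            # Simplified decay applied to the stacked copy, purely to
            # show the guarded path; not the repository's real update.
            w.mul_(1 - lambda_)

# No backward pass has run, so every p.grad is None: w is None and the
# decay is skipped, which is exactly the crash case this commit guards.
w = on_before_optimize()
on_after_optimize(w, lr=1e-4)
print(w)  # None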