path: root/training/strategy
author     Volpeon <git@volpeon.ink>  2023-04-07 11:02:47 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-07 11:02:47 +0200
commit     f5b86b44565aaaa92543989a85ea5d88ca9b1c0c (patch)
tree       df02bdcf757743708001fe70e9db2c3e2b9b4af9 /training/strategy
parent     Update (diff)
download   textual-inversion-diff-f5b86b44565aaaa92543989a85ea5d88ca9b1c0c.tar.gz
           textual-inversion-diff-f5b86b44565aaaa92543989a85ea5d88ca9b1c0c.tar.bz2
           textual-inversion-diff-f5b86b44565aaaa92543989a85ea5d88ca9b1c0c.zip
Fix
Diffstat (limited to 'training/strategy')
-rw-r--r--  training/strategy/lora.py  7
-rw-r--r--  training/strategy/ti.py    7
2 files changed, 8 insertions, 6 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index d51a2f3..6730dc9 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -85,15 +85,16 @@ def lora_strategy_callbacks(
             )
 
         if use_emb_decay:
-            return torch.stack([
+            params = [
                 p
                 for p in text_encoder.text_model.embeddings.token_override_embedding.params
                 if p.grad is not None
-            ])
+            ]
+            return torch.stack(params) if len(params) != 0 else None
 
     @torch.no_grad()
     def on_after_optimize(w, lr: float):
-        if use_emb_decay:
+        if use_emb_decay and w is not None:
             lambda_ = emb_decay * lr
 
             if lambda_ != 0:
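
The hunk above guards torch.stack against an empty parameter list: when none of the override-embedding parameters received a gradient in a step, the old on_before_optimize called torch.stack on an empty list, which raises a RuntimeError. The patch returns None instead and makes on_after_optimize skip the decay when it is handed None. Below is a minimal, self-contained sketch of that pattern; the stand-in parameter list, the w.mul_ decay step, and the driver code at the end are illustrative assumptions, not the repository's actual objects.

import torch

# Stand-ins for the override-embedding parameters; in the repository they
# live at text_encoder.text_model.embeddings.token_override_embedding.params.
params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]

use_emb_decay = True
emb_decay = 1e-2


@torch.no_grad()
def on_before_optimize():
    if use_emb_decay:
        with_grad = [p for p in params if p.grad is not None]
        # torch.stack([]) raises a RuntimeError, so return None when
        # nothing received a gradient this step.
        return torch.stack(with_grad) if len(with_grad) != 0 else None


@torch.no_grad()
def on_after_optimize(w, lr: float):
    # The matching guard: skip the decay when on_before_optimize returned None.
    if use_emb_decay and w is not None:
        lambda_ = emb_decay * lr
        if lambda_ != 0:
            w.mul_(1 - lambda_)  # placeholder decay, not the repository's update


# No parameter has a gradient yet: w is None and the decay is skipped.
on_after_optimize(on_before_optimize(), lr=1e-4)

# Give one parameter a gradient: w is a stacked tensor and the decay runs.
params[0].grad = torch.ones(4)
on_after_optimize(on_before_optimize(), lr=1e-4)

Returning None rather than an empty tensor keeps the callback contract simple: the consumer only has to check for None before applying the decay.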
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 9df160a..55e9934 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -107,18 +107,19 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
-            return torch.stack([
+            params = [
                 p
                 for p in text_encoder.text_model.embeddings.token_override_embedding.params
                 if p.grad is not None
-            ])
+            ]
+            return torch.stack(params) if len(params) != 0 else None
 
     @torch.no_grad()
     def on_after_optimize(w, lr: float):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters())
 
-        if use_emb_decay:
+        if use_emb_decay and w is not None:
             lambda_ = emb_decay * lr
 
             if lambda_ != 0:
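
The ti.py hunk applies the same guard as lora.py. The ema_embeddings.step(...) call stays as it was, since it operates on the full parameter list rather than on a stacked tensor of gradients; only the decay path needs the w is not None check. The failure both hunks avoid can be reproduced outside the repository in a couple of lines (a small, hedged sketch, independent of the repository's code):

import torch

try:
    torch.stack([])  # the call the old code could hit when no parameter had a gradient
except RuntimeError as err:
    print(err)  # e.g. "stack expects a non-empty TensorList"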