summaryrefslogtreecommitdiffstats
path: root/training/strategy
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-27 10:30:26 +0200
committerVolpeon <git@volpeon.ink>2023-03-27 10:30:26 +0200
commitbd8ec551c960fa069482a4b4efd764f60755716b (patch)
tree317464289a769a32c389c68d24239a7d26289b85 /training/strategy
parentRevert to regular embeddings (diff)
downloadtextual-inversion-diff-bd8ec551c960fa069482a4b4efd764f60755716b.tar.gz
textual-inversion-diff-bd8ec551c960fa069482a4b4efd764f60755716b.tar.bz2
textual-inversion-diff-bd8ec551c960fa069482a4b4efd764f60755716b.zip
Fix TI
Diffstat (limited to 'training/strategy')
-rw-r--r--  training/strategy/ti.py | 18
1 file changed, 10 insertions(+), 8 deletions(-)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 7ac5011..b9a5547 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -108,14 +108,11 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
-            return torch.stack([
-                t
-                for t in text_encoder.text_model.embeddings.temp_token_embedding
-                if t.grad is not None
-            ])
+            w = text_encoder.text_model.embeddings.temp_token_embedding.weight
+            return torch.all(w.grad == 0, dim=1)
 
     @torch.no_grad()
-    def on_after_optimize(w, lr: float):
+    def on_after_optimize(zero_ids, lr: float):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
 
@@ -123,8 +120,13 @@ def textual_inversion_strategy_callbacks(
             lambda_ = emb_decay * lr
 
             if lambda_ != 0:
-                norm = w[:, :].norm(dim=-1, keepdim=True)
-                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
+
+                mask = torch.ones(w.shape[0], dtype=torch.bool)
+                mask[zero_ids] = False
+
+                norm = w[mask, :].norm(dim=-1, keepdim=True)
+                w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
 
     def on_log():
         if ema_embeddings is not None: