path: root/training/strategy
author    Volpeon <git@volpeon.ink>  2023-03-27 07:15:46 +0200
committer Volpeon <git@volpeon.ink>  2023-03-27 07:15:46 +0200
commit    0e4c36889aa6b7ec13320a03728118c7c1a8e716 (patch)
tree      461e63354dac6ab5b68d0f57e1569798df5bf202 /training/strategy
parent    Fix TI embeddings init (diff)
download  textual-inversion-diff-0e4c36889aa6b7ec13320a03728118c7c1a8e716.tar.gz
          textual-inversion-diff-0e4c36889aa6b7ec13320a03728118c7c1a8e716.tar.bz2
          textual-inversion-diff-0e4c36889aa6b7ec13320a03728118c7c1a8e716.zip
Sparse TI embeddings without sparse tensors
Diffstat (limited to 'training/strategy')
-rw-r--r--  training/strategy/ti.py  18
1 file changed, 8 insertions(+), 10 deletions(-)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index b9a5547..7ac5011 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -108,11 +108,14 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
-            w = text_encoder.text_model.embeddings.temp_token_embedding.weight
-            return torch.all(w.grad == 0, dim=1)
+            return torch.stack([
+                t
+                for t in text_encoder.text_model.embeddings.temp_token_embedding
+                if t.grad is not None
+            ])
 
     @torch.no_grad()
-    def on_after_optimize(zero_ids, lr: float):
+    def on_after_optimize(w, lr: float):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
 
@@ -120,13 +123,8 @@ def textual_inversion_strategy_callbacks(
             lambda_ = emb_decay * lr
 
             if lambda_ != 0:
-                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
-
-                mask = torch.ones(w.shape[0], dtype=torch.bool)
-                mask[zero_ids] = False
-
-                norm = w[mask, :].norm(dim=-1, keepdim=True)
-                w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
 
     def on_log():
         if ema_embeddings is not None:
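The change assumes temp_token_embedding is an iterable of per-token parameters rather than a single dense weight matrix: only tokens used in a step receive gradients, so gathering the parameters whose grad is not None gives "sparse" updates without torch.sparse. The following is a minimal, self-contained sketch of that pattern, not the repository's actual module; the class name PerTokenEmbeddings, the dimensions, and the default hyperparameters are illustrative, and the norm-decay formula from on_after_optimize is applied here directly to each touched parameter instead of to a stacked copy.

# Sketch of the "sparse TI embeddings without sparse tensors" idea:
# one nn.Parameter per placeholder token, norm decay applied only to
# the parameters that actually received a gradient this step.
import torch
import torch.nn as nn


class PerTokenEmbeddings(nn.Module):
    def __init__(self, num_tokens: int, embed_dim: int):
        super().__init__()
        # One parameter per new token; unused tokens simply get no gradient.
        self.tokens = nn.ParameterList(
            [nn.Parameter(torch.randn(embed_dim) * 0.02) for _ in range(num_tokens)]
        )

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # Stack only the requested token embeddings.
        return torch.stack([self.tokens[i] for i in ids.tolist()])


@torch.no_grad()
def decay_touched_embeddings(
    embeddings: PerTokenEmbeddings,
    lr: float,
    emb_decay: float = 1e-2,        # illustrative defaults, not from the repo
    emb_decay_target: float = 0.4,
):
    # Pull the norm of every embedding that received a gradient toward
    # emb_decay_target, scaled by emb_decay * lr (the update used in the diff).
    lambda_ = emb_decay * lr
    if lambda_ == 0:
        return
    for p in embeddings.tokens:
        if p.grad is None:  # token not used this step: leave it untouched
            continue
        norm = p.norm()
        p.add_((p / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))


if __name__ == "__main__":
    emb = PerTokenEmbeddings(num_tokens=4, embed_dim=8)
    # Pretend a batch only used tokens 0 and 2; only they get gradients.
    loss = emb(torch.tensor([0, 2])).sum()
    loss.backward()
    decay_touched_embeddings(emb, lr=1e-3)

Because untouched parameters carry no gradient, the decay never drags unused placeholder tokens toward the target norm, which is what the removed zero_ids/mask bookkeeping used to guard against.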