From 0e4c36889aa6b7ec13320a03728118c7c1a8e716 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Mon, 27 Mar 2023 07:15:46 +0200
Subject: Sparse TI embeddings without sparse tensors

---
 training/strategy/ti.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

(limited to 'training')

diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index b9a5547..7ac5011 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -108,11 +108,14 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
-            w = text_encoder.text_model.embeddings.temp_token_embedding.weight
-            return torch.all(w.grad == 0, dim=1)
+            return torch.stack([
+                t
+                for t in text_encoder.text_model.embeddings.temp_token_embedding
+                if t.grad is not None
+            ])
 
     @torch.no_grad()
-    def on_after_optimize(zero_ids, lr: float):
+    def on_after_optimize(w, lr: float):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
 
@@ -120,13 +123,8 @@ def textual_inversion_strategy_callbacks(
             lambda_ = emb_decay * lr
 
             if lambda_ != 0:
-                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
-
-                mask = torch.ones(w.shape[0], dtype=torch.bool)
-                mask[zero_ids] = False
-
-                norm = w[mask, :].norm(dim=-1, keepdim=True)
-                w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
 
     def on_log():
         if ema_embeddings is not None:
-- 
cgit v1.2.3-70-g09d2