From 30b557c8e1f03b4748ac3efca599ff51d66561cb Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Tue, 4 Apr 2023 07:30:43 +0200
Subject: TI: Bring back old embedding decay

---
 training/strategy/ti.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

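Notes:

The re-added decay pulls the L2 norm of each trained placeholder embedding
toward emb_decay_target (default 0.4) at a rate of emb_decay * lr per step:
on_before_optimize collects the embeddings that received gradients, the
training loop appears to hand that return value back as on_after_optimize's
first argument (as the replaced zero_ids parameter suggests), and the update
w += (emb_decay * lr) * (emb_decay_target - |w|) * (w / |w|) is applied after
the optimizer step.

A minimal, self-contained sketch of the update rule; apply_emb_decay is a
hypothetical helper for illustration only, and lr=1.0 is chosen purely to
make the convergence visible:

    import torch

    def apply_emb_decay(
        w: torch.Tensor,
        lr: float,
        emb_decay: float = 1e-2,
        emb_decay_target: float = 0.4,
    ) -> None:
        """Nudge the L2 norm of embedding vector `w` toward
        `emb_decay_target`, scaled by the current learning rate.
        Mutates `w` in place."""
        lambda_ = emb_decay * lr
        if lambda_ != 0:
            norm = w.norm(dim=-1, keepdim=True)
            # (emb_decay_target - norm) is positive for short vectors
            # (grow) and negative for long ones (shrink), so the norm
            # converges on the target without changing the direction.
            w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))

    w = torch.randn(768) * 3  # an over-long embedding
    for _ in range(1_000):
        apply_emb_decay(w, lr=1.0)
    print(w.norm())  # ~0.4

Per step the norm follows n <- n * (1 - lambda) + lambda * target, a
geometric approach to the target, so small emb_decay values only bias the
embeddings gently rather than hard-clamping them to the target norm.
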
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 95128da..9df160a 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -31,6 +31,9 @@ def textual_inversion_strategy_callbacks(
     seed: int,
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
     ema_power: int = 1,
@@ -102,10 +105,26 @@ def textual_inversion_strategy_callbacks(
             yield
 
     @torch.no_grad()
-    def on_after_optimize(zero_ids, lr: float):
+    def on_before_optimize(lr: float, epoch: int):
+        if use_emb_decay:
+            return [
+                p
+                for p in text_encoder.text_model.embeddings.token_override_embedding.params
+                if p.grad is not None
+            ]
+
+    @torch.no_grad()
+    def on_after_optimize(w, lr: float):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters())
 
+        if use_emb_decay:
+            lambda_ = emb_decay * lr
+            if lambda_ != 0:
+                for p in w:
+                    norm = p.norm(dim=-1, keepdim=True)
+                    p.add_((p / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+
     def on_log():
         if ema_embeddings is not None:
             return {"ema_decay": ema_embeddings.decay}
@@ -149,6 +168,7 @@ def textual_inversion_strategy_callbacks(
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
+        on_before_optimize=on_before_optimize,
         on_after_optimize=on_after_optimize,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
-- 
cgit v1.2.3-70-g09d2