From c96073646bbb638d7d78fdd7d9fdeed08d1454b5 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Sat, 1 Apr 2023 16:30:36 +0200
Subject: Experimental: TI via LoRA

---
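Note on the change below: the patch swaps the dedicated `temp_token_embedding` for an
`overlay` attribute on the text encoder's embeddings and drops the `use_emb_decay`
path entirely, including the `on_before_optimize` hook that used to produce `zero_ids`.
The overlay module itself is not part of this diff; the following is only a minimal
sketch of one way a LoRA-style embedding overlay could look, assuming names, rank, and
shapes that are illustrative rather than taken from the actual implementation.

    # Hypothetical sketch of a LoRA-style embedding overlay (not the real module):
    # the learned token embeddings are parameterised as a low-rank additive delta
    # on top of the frozen base token embedding, so only two small factors train.
    import torch
    import torch.nn as nn

    class LoraEmbeddingOverlay(nn.Module):
        def __init__(self, num_embeddings: int, embedding_dim: int,
                     rank: int = 8, alpha: float = 8.0):
            super().__init__()
            # lora_up starts at zero, so the overlay is a no-op until training
            # moves it away from the base embedding.
            self.lora_down = nn.Parameter(torch.randn(num_embeddings, rank) * 0.01)
            self.lora_up = nn.Parameter(torch.zeros(rank, embedding_dim))
            self.scale = alpha / rank

        def forward(self, base_embeds: torch.Tensor,
                    input_ids: torch.Tensor) -> torch.Tensor:
            # Per-token low-rank delta, gathered by id and added to the frozen lookup.
            delta = (self.lora_down[input_ids] @ self.lora_up) * self.scale
            return base_embeds + delta

With a module shaped like this, the `overlay.parameters()` calls in the callbacks below
would yield just the two factor matrices, which is what the EMA model steps over after
this change; the per-embedding norm decay presumably becomes unnecessary once the
update is constrained to a low-rank delta, which would explain its removal here.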
 training/strategy/ti.py | 30 ++++--------------------------
 1 file changed, 4 insertions(+), 26 deletions(-)

diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index b9a5547..19b8d25 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -32,9 +32,6 @@ def textual_inversion_strategy_callbacks(
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
     gradient_checkpointing: bool = False,
-    use_emb_decay: bool = False,
-    emb_decay_target: float = 0.4,
-    emb_decay: float = 1e-2,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
     ema_power: int = 1,
@@ -73,7 +70,7 @@ def textual_inversion_strategy_callbacks(
 
     if use_ema:
         ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            text_encoder.text_model.embeddings.overlay.parameters(),
             inv_gamma=ema_inv_gamma,
             power=ema_power,
             max_value=ema_max_decay,
@@ -85,13 +82,13 @@ def textual_inversion_strategy_callbacks(
     def ema_context():
         if ema_embeddings is not None:
             return ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+                text_encoder.text_model.embeddings.overlay.parameters()
             )
         else:
             return nullcontext()
 
     def on_accum_model():
-        return text_encoder.text_model.embeddings.temp_token_embedding
+        return text_encoder.text_model.embeddings.overlay
 
     @contextmanager
     def on_train(epoch: int):
@@ -105,28 +102,10 @@ def textual_inversion_strategy_callbacks(
         with ema_context():
             yield
 
-    @torch.no_grad()
-    def on_before_optimize(lr: float, epoch: int):
-        if use_emb_decay:
-            w = text_encoder.text_model.embeddings.temp_token_embedding.weight
-            return torch.all(w.grad == 0, dim=1)
-
     @torch.no_grad()
     def on_after_optimize(zero_ids, lr: float):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-        if use_emb_decay:
-            lambda_ = emb_decay * lr
-
-            if lambda_ != 0:
-                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
-
-                mask = torch.ones(w.shape[0], dtype=torch.bool)
-                mask[zero_ids] = False
-
-                norm = w[mask, :].norm(dim=-1, keepdim=True)
-                w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+            ema_embeddings.step(text_encoder.text_model.embeddings.overlay.parameters())
 
     def on_log():
         if ema_embeddings is not None:
@@ -171,7 +150,6 @@ def textual_inversion_strategy_callbacks(
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
-        on_before_optimize=on_before_optimize,
         on_after_optimize=on_after_optimize,
         on_log=on_log,
         on_checkpoint=on_checkpoint,