From 555912a86b012382a78f1b2717c2e0fde5994a04 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Tue, 17 Jan 2023 11:50:16 +0100
Subject: Make embedding decay work like Adam decay

---
 training/strategy/ti.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 081180f..eb6730b 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -32,12 +32,10 @@ def textual_inversion_strategy_callbacks(
     seed: int,
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
-    learning_rate: float,
     gradient_checkpointing: bool = False,
     use_emb_decay: bool = False,
     emb_decay_target: float = 0.4,
-    emb_decay_factor: float = 1,
-    emb_decay_start: float = 0,
+    emb_decay: float = 1e-2,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
     ema_power: int = 1,
@@ -120,17 +118,15 @@ def textual_inversion_strategy_callbacks(
             yield
 
     def on_after_optimize(lr: float):
-        if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-    @torch.no_grad()
-    def on_after_epoch(lr: float):
         if use_emb_decay:
             text_encoder.text_model.embeddings.normalize(
                 emb_decay_target,
-                min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start))))
+                min(1.0, emb_decay * lr)
             )
 
+        if ema_embeddings is not None:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
     def on_log():
         if ema_embeddings is not None:
             return {"ema_decay": ema_embeddings.decay}