From b31fcb741432076f7e2f3ec9423ad935a08c6671 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Tue, 16 May 2023 07:12:14 +0200
Subject: Support LoRA training for token embeddings

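Embedding weight decay so far keyed only on use_emb_decay. When no
placeholder tokens are being trained there are no token embeddings to
decay, so on_before_optimize now additionally checks that
placeholder_tokens is non-empty before collecting the embedding
parameters. on_after_optimize in turn drops its own use_emb_decay
check: judging by the diff, w is only non-None when that collection
ran, so testing w together with the "emb" learning rate is sufficient.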
---
 training/strategy/lora.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 0c0f633..f942b76 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -92,7 +92,7 @@ def lora_strategy_callbacks(
                 max_grad_norm
             )
 
-        if use_emb_decay:
+        if len(placeholder_tokens) != 0 and use_emb_decay:
             params = [
                 p
                 for p in text_encoder.text_model.embeddings.parameters()
@@ -102,7 +102,7 @@ def lora_strategy_callbacks(
 
     @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
-        if use_emb_decay and w is not None and "emb" in lrs:
+        if w is not None and "emb" in lrs:
             lr = lrs["emb"]
             lambda_ = emb_decay * lr
 
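
For context, here is a minimal, hypothetical sketch of the decay step
these conditions guard. The hunks above show only the guards, not the
update itself, so the function name, the emb_decay_target parameter,
and the exact update rule below are assumptions: a norm decay that
nudges each trained embedding's L2 norm toward a target, scaled by the
"emb" learning rate.

    import torch

    @torch.no_grad()
    def apply_emb_decay(w: torch.Tensor, lr: float, emb_decay: float,
                        emb_decay_target: float = 1.0) -> None:
        # Hypothetical standalone version of the decay applied in
        # on_after_optimize. w: (num_placeholder_tokens, embedding_dim)
        # slice of the trainable token embeddings.
        lambda_ = emb_decay * lr
        if lambda_ == 0:
            return
        # Per-row L2 norm, clamped to avoid division by zero.
        norm = w.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        # Move each embedding along its own direction so its norm
        # drifts toward emb_decay_target at rate lambda_.
        w.add_(lambda_ * (w / norm) * (emb_decay_target - norm))

Matching the patched conditions, a caller would run this only when w
is not None and lrs contains an "emb" entry.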