From b31fcb741432076f7e2f3ec9423ad935a08c6671 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 16 May 2023 07:12:14 +0200 Subject: Support LoRA training for token embeddings --- training/strategy/lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'training/strategy/lora.py') diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 0c0f633..f942b76 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py @@ -92,7 +92,7 @@ def lora_strategy_callbacks( max_grad_norm ) - if use_emb_decay: + if len(placeholder_tokens) != 0 and use_emb_decay: params = [ p for p in text_encoder.text_model.embeddings.parameters() @@ -102,7 +102,7 @@ def lora_strategy_callbacks( @torch.no_grad() def on_after_optimize(w, lrs: dict[str, float]): - if use_emb_decay and w is not None and "emb" in lrs: + if w is not None and "emb" in lrs: lr = lrs["emb"] lambda_ = emb_decay * lr -- cgit v1.2.3-54-g00ecf