From f5b86b44565aaaa92543989a85ea5d88ca9b1c0c Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 7 Apr 2023 11:02:47 +0200
Subject: Fix

---
 training/strategy/lora.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index d51a2f3..6730dc9 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -85,15 +85,16 @@ def lora_strategy_callbacks(
             )
 
         if use_emb_decay:
-            return torch.stack([
+            params = [
                 p
                 for p in text_encoder.text_model.embeddings.token_override_embedding.params
                 if p.grad is not None
-            ])
+            ]
+            return torch.stack(params) if len(params) != 0 else None
 
     @torch.no_grad()
     def on_after_optimize(w, lr: float):
-        if use_emb_decay:
+        if use_emb_decay and w is not None:
             lambda_ = emb_decay * lr
 
             if lambda_ != 0:
-- 
cgit v1.2.3-54-g00ecf
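
Why the change is needed: torch.stack raises a RuntimeError when handed an
empty list, which the old on_before_optimize path hit whenever no
token-override embedding parameter had received a gradient in the current
step. The patch collects the parameters first, returns None when the list is
empty, and makes on_after_optimize skip the decay when w is None. A minimal
standalone sketch of the failure mode and the guard (assuming only PyTorch;
the names below are illustrative, not taken from the patched file):

    import torch

    # Fresh parameters: no backward pass has run, so every p.grad is None.
    params = [torch.randn(4, requires_grad=True) for _ in range(3)]
    with_grad = [p for p in params if p.grad is not None]

    # The old code called torch.stack(...) unconditionally; on an empty list
    # this raises "RuntimeError: stack expects a non-empty TensorList".
    w = torch.stack(with_grad) if len(with_grad) != 0 else None
    assert w is None  # nothing had a gradient, so the callback yields None

    # Mirrors the patched on_after_optimize: apply the decay only when some
    # parameters were actually collected.
    if w is not None:
        print("would apply embedding decay to a tensor of shape", w.shape)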