From f5b86b44565aaaa92543989a85ea5d88ca9b1c0c Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 7 Apr 2023 11:02:47 +0200
Subject: Fix

---
 training/strategy/lora.py | 7 ++++---
 training/strategy/ti.py   | 7 ++++---
 2 files changed, 8 insertions(+), 6 deletions(-)

(limited to 'training/strategy')

diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index d51a2f3..6730dc9 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -85,15 +85,16 @@ def lora_strategy_callbacks(
             )
 
         if use_emb_decay:
-            return torch.stack([
+            params = [
                 p
                 for p in text_encoder.text_model.embeddings.token_override_embedding.params
                 if p.grad is not None
-            ])
+            ]
+            return torch.stack(params) if len(params) != 0 else None
 
     @torch.no_grad()
     def on_after_optimize(w, lr: float):
-        if use_emb_decay:
+        if use_emb_decay and w is not None:
             lambda_ = emb_decay * lr
 
             if lambda_ != 0:
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 9df160a..55e9934 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -107,18 +107,19 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
-            return torch.stack([
+            params = [
                 p
                 for p in text_encoder.text_model.embeddings.token_override_embedding.params
                 if p.grad is not None
-            ])
+            ]
+            return torch.stack(params) if len(params) != 0 else None
 
     @torch.no_grad()
     def on_after_optimize(w, lr: float):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters())
 
-        if use_emb_decay:
+        if use_emb_decay and w is not None:
             lambda_ = emb_decay * lr
 
             if lambda_ != 0:
-- 
cgit v1.2.3-70-g09d2
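
The commit message only says "Fix", so the snippet below is a minimal, hypothetical sketch (not part of the patch) of the failure mode the new guard appears to address: torch.stack() raises a RuntimeError when given an empty list, so a step in which no override-embedding parameter received a gradient would crash the old on_before_optimize. The patch collects the parameters first, returns None when the list is empty, and skips the decay in on_after_optimize when w is None.

    # Hypothetical illustration, assuming on_before_optimize's return value is
    # passed to on_after_optimize as `w` (as the callback signatures suggest).
    import torch

    params = []  # e.g. no embedding parameter had a gradient this step

    # Old behaviour: torch.stack(params) on an empty list raises RuntimeError.
    # New behaviour: return None instead of stacking an empty list.
    w = torch.stack(params) if len(params) != 0 else None

    # Mirrors the added `w is not None` check before applying embedding decay.
    if w is not None:
        w.mul_(0.99)  # placeholder for the actual decay applied to w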