From 4e77936376e7a1fa9b16ccc6af6650233825161c Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 8 Apr 2023 08:46:47 +0200 Subject: Fix TI --- training/strategy/ti.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'training/strategy') diff --git a/training/strategy/ti.py b/training/strategy/ti.py index d735dac..720ebf3 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -114,11 +114,12 @@ def textual_inversion_strategy_callbacks( return torch.stack(params) if len(params) != 0 else None @torch.no_grad() - def on_after_optimize(w, lr: float): + def on_after_optimize(w, lrs: dict[str, float]): if ema_embeddings is not None: ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters()) if use_emb_decay and w is not None: + lr = lrs["emb"] or lrs["0"] lambda_ = emb_decay * lr if lambda_ != 0: -- cgit v1.2.3-70-g09d2