From b5e0ef7b8a4629c2d1885a96f0faf24fafba1467 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 26 Mar 2023 14:29:57 +0200 Subject: Improved TI embeddings --- training/strategy/ti.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'training') diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 10bc6d7..b9a5547 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -122,8 +122,7 @@ def textual_inversion_strategy_callbacks( if lambda_ != 0: w = text_encoder.text_model.embeddings.temp_token_embedding.weight - mask = torch.zeros(w.shape[0], dtype=torch.bool) - mask[text_encoder.text_model.embeddings.temp_token_ids] = True + mask = torch.ones(w.shape[0], dtype=torch.bool) mask[zero_ids] = False norm = w[mask, :].norm(dim=-1, keepdim=True) -- cgit v1.2.3-54-g00ecf