diff options
Diffstat (limited to 'training/strategy')
-rw-r--r-- | training/strategy/ti.py | 3 |
1 files changed, 1 insertions, 2 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 10bc6d7..b9a5547 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -122,8 +122,7 @@ def textual_inversion_strategy_callbacks( | |||
122 | if lambda_ != 0: | 122 | if lambda_ != 0: |
123 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight | 123 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight |
124 | 124 | ||
125 | mask = torch.zeros(w.shape[0], dtype=torch.bool) | 125 | mask = torch.ones(w.shape[0], dtype=torch.bool) |
126 | mask[text_encoder.text_model.embeddings.temp_token_ids] = True | ||
127 | mask[zero_ids] = False | 126 | mask[zero_ids] = False |
128 | 127 | ||
129 | norm = w[mask, :].norm(dim=-1, keepdim=True) | 128 | norm = w[mask, :].norm(dim=-1, keepdim=True) |