diff options
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r-- | training/strategy/ti.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 732cd74..bd0d178 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -130,7 +130,7 @@ def textual_inversion_strategy_callbacks( | |||
130 | if lambda_ != 0: | 130 | if lambda_ != 0: |
131 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight | 131 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight |
132 | 132 | ||
133 | mask = torch.zeros(w.size(0), dtype=torch.bool) | 133 | mask = torch.zeros(w.shape[0], dtype=torch.bool) |
134 | mask[text_encoder.text_model.embeddings.temp_token_ids] = True | 134 | mask[text_encoder.text_model.embeddings.temp_token_ids] = True |
135 | mask[zero_ids] = False | 135 | mask[zero_ids] = False |
136 | 136 | ||