summaryrefslogtreecommitdiffstats
path: root/training/strategy/ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-01 12:34:42 +0100
committerVolpeon <git@volpeon.ink>2023-03-01 12:34:42 +0100
commita1b8327085ddeab589be074d7e9df4291aba1210 (patch)
tree2f2016916d7a2f659268c3e375d55c59583c2b3b /training/strategy/ti.py
parentFixed TI normalization order (diff)
downloadtextual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.tar.gz
textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.tar.bz2
textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.zip
Update
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r--training/strategy/ti.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 732cd74..bd0d178 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -130,7 +130,7 @@ def textual_inversion_strategy_callbacks(
130 if lambda_ != 0: 130 if lambda_ != 0:
131 w = text_encoder.text_model.embeddings.temp_token_embedding.weight 131 w = text_encoder.text_model.embeddings.temp_token_embedding.weight
132 132
133 mask = torch.zeros(w.size(0), dtype=torch.bool) 133 mask = torch.zeros(w.shape[0], dtype=torch.bool)
134 mask[text_encoder.text_model.embeddings.temp_token_ids] = True 134 mask[text_encoder.text_model.embeddings.temp_token_ids] = True
135 mask[zero_ids] = False 135 mask[zero_ids] = False
136 136