diff options
author | Volpeon <git@volpeon.ink> | 2023-03-01 12:34:42 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-01 12:34:42 +0100 |
commit | a1b8327085ddeab589be074d7e9df4291aba1210 (patch) | |
tree | 2f2016916d7a2f659268c3e375d55c59583c2b3b /training/strategy/ti.py | |
parent | Fixed TI normalization order (diff) | |
download | textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.tar.gz textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.tar.bz2 textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.zip |
Update
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r-- | training/strategy/ti.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 732cd74..bd0d178 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -130,7 +130,7 @@ def textual_inversion_strategy_callbacks( | |||
130 | if lambda_ != 0: | 130 | if lambda_ != 0: |
131 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight | 131 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight |
132 | 132 | ||
133 | mask = torch.zeros(w.size(0), dtype=torch.bool) | 133 | mask = torch.zeros(w.shape[0], dtype=torch.bool) |
134 | mask[text_encoder.text_model.embeddings.temp_token_ids] = True | 134 | mask[text_encoder.text_model.embeddings.temp_token_ids] = True |
135 | mask[zero_ids] = False | 135 | mask[zero_ids] = False |
136 | 136 | ||