diff options
author | Volpeon <git@volpeon.ink> | 2023-03-26 14:29:57 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-26 14:29:57 +0200 |
commit | b5e0ef7b8a4629c2d1885a96f0faf24fafba1467 (patch) | |
tree | 675749d04db22ffca4ca0eb74449c1242c582bc4 /training | |
parent | Improved inverted tokens (diff) | |
download | textual-inversion-diff-b5e0ef7b8a4629c2d1885a96f0faf24fafba1467.tar.gz textual-inversion-diff-b5e0ef7b8a4629c2d1885a96f0faf24fafba1467.tar.bz2 textual-inversion-diff-b5e0ef7b8a4629c2d1885a96f0faf24fafba1467.zip |
Improved TI embeddings
Diffstat (limited to 'training')
-rw-r--r-- | training/strategy/ti.py | 3 |
1 files changed, 1 insertions, 2 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 10bc6d7..b9a5547 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -122,8 +122,7 @@ def textual_inversion_strategy_callbacks( | |||
122 | if lambda_ != 0: | 122 | if lambda_ != 0: |
123 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight | 123 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight |
124 | 124 | ||
125 | mask = torch.zeros(w.shape[0], dtype=torch.bool) | 125 | mask = torch.ones(w.shape[0], dtype=torch.bool) |
126 | mask[text_encoder.text_model.embeddings.temp_token_ids] = True | ||
127 | mask[zero_ids] = False | 126 | mask[zero_ids] = False |
128 | 127 | ||
129 | norm = w[mask, :].norm(dim=-1, keepdim=True) | 128 | norm = w[mask, :].norm(dim=-1, keepdim=True) |