diff options
author | Volpeon <git@volpeon.ink> | 2023-03-26 16:34:31 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-26 16:34:31 +0200 |
commit | e87fd8c55397db2bdf5177c42d2013037a9b9896 (patch) | |
tree | 264bd8b3d60c140d3e65abac0ed6d031403c4ccf /models/clip | |
parent | Improved TI embeddings (diff) | |
download | textual-inversion-diff-e87fd8c55397db2bdf5177c42d2013037a9b9896.tar.gz textual-inversion-diff-e87fd8c55397db2bdf5177c42d2013037a9b9896.tar.bz2 textual-inversion-diff-e87fd8c55397db2bdf5177c42d2013037a9b9896.zip |
Fix TI embeddings init
Diffstat (limited to 'models/clip')
-rw-r--r-- | models/clip/embeddings.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 8d01867..870ee49 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -77,7 +77,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
77 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 77 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
78 | 78 | ||
79 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 79 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
80 | mask = torch.nonzero(self.temp_token_ids == token_ids).squeeze(1) | 80 | mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1) |
81 | self.temp_token_embedding.weight.data[mask] = initializer.to( | 81 | self.temp_token_embedding.weight.data[mask] = initializer.to( |
82 | device=self.temp_token_embedding.weight.device, | 82 | device=self.temp_token_embedding.weight.device, |
83 | dtype=self.temp_token_embedding.weight.dtype, | 83 | dtype=self.temp_token_embedding.weight.dtype, |