diff options
-rw-r--r-- | models/clip/embeddings.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 8d01867..870ee49 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -77,7 +77,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
77 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 77 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
78 | 78 | ||
79 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 79 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
80 | mask = torch.nonzero(self.temp_token_ids == token_ids).squeeze(1) | 80 | mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1) |
81 | self.temp_token_embedding.weight.data[mask] = initializer.to( | 81 | self.temp_token_embedding.weight.data[mask] = initializer.to( |
82 | device=self.temp_token_embedding.weight.device, | 82 | device=self.temp_token_embedding.weight.device, |
83 | dtype=self.temp_token_embedding.weight.dtype, | 83 | dtype=self.temp_token_embedding.weight.dtype, |