diff options
Diffstat (limited to 'models/clip')
| -rw-r--r-- | models/clip/embeddings.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 8d01867..870ee49 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -77,7 +77,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 77 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 77 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
| 78 | 78 | ||
| 79 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 79 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
| 80 | mask = torch.nonzero(self.temp_token_ids == token_ids).squeeze(1) | 80 | mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1) |
| 81 | self.temp_token_embedding.weight.data[mask] = initializer.to( | 81 | self.temp_token_embedding.weight.data[mask] = initializer.to( |
| 82 | device=self.temp_token_embedding.weight.device, | 82 | device=self.temp_token_embedding.weight.device, |
| 83 | dtype=self.temp_token_embedding.weight.dtype, | 83 | dtype=self.temp_token_embedding.weight.dtype, |
