From e87fd8c55397db2bdf5177c42d2013037a9b9896 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 26 Mar 2023 16:34:31 +0200 Subject: Fix TI embeddings init --- models/clip/embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'models/clip') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 8d01867..870ee49 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -77,7 +77,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): token_ids = torch.tensor(token_ids, dtype=torch.long) self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) - mask = torch.nonzero(self.temp_token_ids == token_ids).squeeze(1) + mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1) self.temp_token_embedding.weight.data[mask] = initializer.to( device=self.temp_token_embedding.weight.device, dtype=self.temp_token_embedding.weight.dtype, -- cgit v1.2.3-70-g09d2