From c40170386fd055f715db90886f0ac0da5c575bd9 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 27 Mar 2023 13:19:05 +0200 Subject: Fix TI --- models/clip/embeddings.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) (limited to 'models') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 2d60c28..e8cc865 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -38,18 +38,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): self.token_embedding = embeddings.token_embedding self.position_embedding = embeddings.position_embedding self.initializer_factor = config.initializer_factor + self.init_temp_embeddings() + def init_temp_embeddings(self): self.temp_token_embedding = nn.Embedding( - self.token_embedding.num_embeddings, + 0, self.token_embedding.embedding_dim, device=self.token_embedding.weight.device, dtype=self.token_embedding.weight.dtype ) - self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach() self.temp_token_ids = torch.tensor([], dtype=torch.long) def resize(self, size: int): - self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): @@ -74,9 +74,17 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): ) token_ids = torch.tensor(token_ids, dtype=torch.long) - self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) - self.temp_token_embedding.weight.data[token_ids] = initializer + + self.temp_token_embedding = resize_embedding( + self.temp_token_embedding, + self.temp_token_ids.shape[0], + self.initializer_factor + ) + + mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1) + self.temp_token_embedding.weight.data[mask] = initializer + self.token_embedding.weight.data[token_ids] = initializer def load_embed(self, input_ids: list[int], filename: Path): with safe_open(filename, framework="pt", device="cpu") as file: @@ -86,17 +94,25 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): save_file({"embed": self.get_embed(input_ids)}, filename) def persist(self): - self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] - self.temp_token_ids = torch.tensor([], dtype=torch.long) + self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[:] + self.init_temp_embeddings() def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): if isinstance(input_ids, list): input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) + all_temp_token_ids = self.temp_token_ids.to(input_ids.device) + embeds = self.token_embedding(input_ids) - mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) - embeds[mask] = self.temp_token_embedding(input_ids)[mask] + embeds_mask = torch.isin(input_ids, all_temp_token_ids) + temp_token_ids = input_ids[embeds_mask] + + temp_token_ids = temp_token_ids.unsqueeze(1) + all_temp_token_ids = all_temp_token_ids.unsqueeze(0) + temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() + + embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids) return embeds -- cgit v1.2.3-70-g09d2