From a412196d1a3b616655de52fb12e0d8528e1f1af0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 27 Mar 2023 10:16:10 +0200 Subject: Revert to regular embeddings --- models/clip/embeddings.py | 34 +++++++++++++++------------------- 1 file changed, 15 insertions(+), 19 deletions(-) (limited to 'models') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 95904cf..2b315c4 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -42,16 +42,20 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): self.init_temp_embeddings() def init_temp_embeddings(self): - self.temp_token_embedding = nn.ParameterList() + self.temp_token_embedding = nn.Embedding( + 0, + self.token_embedding.embedding_dim, + device=self.token_embedding.weight.device, + dtype=self.token_embedding.weight.dtype + ) self.temp_token_ids = torch.tensor([], dtype=torch.long) def resize(self, size: int): - for _ in range(len(self.temp_token_embedding), size): - self.temp_token_embedding.append(torch.zeros( - self.token_embedding.embedding_dim, - device=self.token_embedding.weight.device, - dtype=self.token_embedding.weight.dtype, - )) + self.temp_token_embedding = resize_embedding( + self.temp_token_embedding, + size - self.num_permanent_embeddings, + self.initializer_factor + ) self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): @@ -78,10 +82,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): token_ids = torch.tensor(token_ids, dtype=torch.long) self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) - mask = torch.nonzero(torch.isin(self.temp_token_ids, token_ids)).squeeze(1) - - for i, id in enumerate(mask): - self.temp_token_embedding[id] = initializer[i] + mask = torch.nonzero(self.temp_token_ids == token_ids).squeeze(1) + self.temp_token_embedding.weight.data[mask] = initializer def load_embed(self, input_ids: list[int], filename: Path): with safe_open(filename, framework="pt", device="cpu") as file: @@ -91,8 +93,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): save_file({"embed": self.get_embed(input_ids)}, filename) def persist(self): - for id, emb in zip(self.temp_token_ids, self.temp_token_embedding): - self.token_embedding.weight.data[id] = emb + self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] self.num_permanent_embeddings = self.token_embedding.num_embeddings self.init_temp_embeddings() @@ -111,12 +112,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): all_temp_token_ids = all_temp_token_ids.unsqueeze(0) temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() - if len(temp_token_ids): - embeds_override = torch.stack([ - self.temp_token_embedding[id] - for id in temp_token_ids - ]) - embeds[embeds_mask] = embeds_override + embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids) return embeds -- cgit v1.2.3-70-g09d2