From 2e654c017780d37f3304436e2feb84b619f1c023 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 3 Apr 2023 22:25:20 +0200
Subject: Improved sparse embeddings

---
 models/clip/embeddings.py | 52 +++++++++++++++++++----------------------------
 1 file changed, 21 insertions(+), 31 deletions(-)

(limited to 'models/clip')

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index d8343a0..a356434 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -11,6 +11,8 @@ from transformers import CLIPTextModel
 from transformers.models.clip import CLIPTextConfig
 from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
+from models.sparse import PseudoSparseEmbedding
+
 
 def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding:
     old_num_embeddings, old_embedding_dim = old_embedding.weight.shape
@@ -41,18 +43,16 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
-        self.alpha = alpha
 
-        self.temp_token_embedding = nn.ParameterList()
-        self.temp_token_ids = torch.tensor([], dtype=torch.long)
+        self.token_override_embedding = PseudoSparseEmbedding(
+            self.token_embedding.embedding_dim,
+            device=self.token_embedding.weight.device,
+            dtype=self.token_embedding.weight.dtype,
+        )
+        self.alpha = alpha
 
     def resize(self, size: int):
-        for _ in range(len(self.temp_token_embedding), size):
-            self.temp_token_embedding.append(torch.zeros(
-                self.token_embedding.embedding_dim,
-                device=self.token_embedding.weight.device,
-                dtype=self.token_embedding.weight.dtype,
-            ))
+        self.token_override_embedding.resize(size)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(
@@ -86,8 +86,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
-        self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
         self.token_embedding.weight.data[token_ids] = initializer
+        self.token_override_embedding.set(token_ids)
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
@@ -97,33 +97,23 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        for id, emb in zip(self.temp_token_ids, self.temp_token_embedding):
-            self.token_embedding.weight.data[id] += self.alpha * emb
-            nn.init.zeros_(emb)
-        self.temp_token_ids = torch.tensor([], dtype=torch.long)
+        input_ids = torch.arange(self.token_embedding.num_embeddings)
+        embs, mask = self.token_override_embedding(input_ids)
+        if embs is not None:
+            input_ids = input_ids[mask]
+            self.token_embedding.weight.data[input_ids] += self.alpha * embs
+            self.token_override_embedding.unset(input_ids)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
-        all_temp_token_ids = self.temp_token_ids.to(input_ids.device)
-
-        embeds = self.token_embedding(input_ids)
-        mask = torch.isin(input_ids, all_temp_token_ids)
-        temp_token_ids = input_ids[mask]
-
-        temp_token_ids = temp_token_ids.unsqueeze(1)
-        all_temp_token_ids = all_temp_token_ids.unsqueeze(0)
-        temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze()
-
-        if len(temp_token_ids):
-            embeds_override = torch.stack([
-                self.temp_token_embedding[id]
-                for id in temp_token_ids
-            ])
-            embeds[mask] += self.alpha * embeds_override
+        embs = self.token_embedding(input_ids)
+        embs_override, mask = self.token_override_embedding(input_ids)
+        if embs_override is not None:
+            embs[mask] += self.alpha * embs_override
 
-        return embeds
+        return embs
 
     def forward(
         self,
--
cgit v1.2.3-70-g09d2
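Note: the patch imports PseudoSparseEmbedding from models.sparse, which is not part of
this diff. The sketch below only illustrates the interface the call sites above appear
to rely on (resize, set, unset, and a call returning (override embeddings, mask)); the
class name and its usage come from the patch, everything else is an assumption and may
differ from the actual models/sparse.py.

from typing import Optional

import torch
import torch.nn as nn


class PseudoSparseEmbedding(nn.Module):
    # "Pseudo" sparse: a small list of trainable override vectors plus a dense
    # token-id -> slot mapping, where -1 means "no override for this token".
    def __init__(self, embedding_dim: int, device=None, dtype=torch.float32):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.dtype = dtype
        self.params = nn.ParameterList()
        self.register_buffer(
            "mapping", torch.full((0,), -1, device=device, dtype=torch.long)
        )

    def resize(self, new_num_embeddings: int):
        # Grow or shrink the mapping; newly added token ids start without an override.
        old = self.mapping.shape[0]
        if new_num_embeddings > old:
            pad = torch.full(
                (new_num_embeddings - old,), -1,
                device=self.mapping.device, dtype=torch.long,
            )
            self.mapping = torch.cat([self.mapping, pad])
        else:
            self.mapping = self.mapping[:new_num_embeddings]

    def set(self, input_ids: torch.LongTensor, values: Optional[torch.Tensor] = None):
        # Register a (zero-initialized) trainable override for each given token id.
        for i, token_id in enumerate(input_ids.tolist()):
            value = values[i] if values is not None else torch.zeros(
                self.embedding_dim, device=self.mapping.device, dtype=self.dtype
            )
            slot = int(self.mapping[token_id])
            if slot == -1:
                self.mapping[token_id] = len(self.params)
                self.params.append(nn.Parameter(value.clone()))
            else:
                self.params[slot].data = value.clone()

    def unset(self, input_ids: torch.LongTensor):
        # Drop the overrides for the given token ids (used by persist() above).
        self.mapping[input_ids] = -1

    def forward(self, input_ids: torch.LongTensor):
        # Returns (stacked overrides for the overridden ids, boolean mask over
        # input_ids), or (None, mask) when none of the requested ids is overridden.
        mask = (self.mapping[input_ids] != -1).to(input_ids.device)
        if not mask.any():
            return None, mask
        slots = self.mapping[input_ids[mask]]
        embs = torch.stack([self.params[int(s)] for s in slots])
        return embs, mask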