From 2e654c017780d37f3304436e2feb84b619f1c023 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 3 Apr 2023 22:25:20 +0200
Subject: Improved sparse embeddings

---
 models/sparse.py | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)
 create mode 100644 models/sparse.py

diff --git a/models/sparse.py b/models/sparse.py
new file mode 100644
index 0000000..0b15454
--- /dev/null
+++ b/models/sparse.py
@@ -0,0 +1,57 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+
+class PseudoSparseEmbedding(nn.Module):
+    def __init__(self, embedding_dim: int, device=None, dtype=torch.float32):
+        super().__init__()
+
+        self.embedding_dim = embedding_dim
+        self.dtype = dtype
+        self.params = nn.ParameterList()
+        self.mapping = torch.zeros(0, device=device, dtype=torch.long)
+
+    def forward(self, input_ids: Optional[torch.LongTensor] = None):
+        if input_ids is None:
+            input_ids = torch.arange(self.mapping.shape[0])
+
+        ids = self.mapping[input_ids.to(self.mapping.device)]
+        mask = ~(ids == -1)
+
+        if torch.all(~mask):
+            embs = None
+        else:
+            embs = torch.stack([self.params[id] for id in ids[mask]])
+
+        return embs, mask
+
+    def resize(self, new_num_embeddings: int):
+        old_num_embeddings = self.mapping.shape[0]
+        n = min(old_num_embeddings, new_num_embeddings)
+
+        new_mapping = torch.zeros(new_num_embeddings, device=self.mapping.device, dtype=torch.long) - 1
+        new_mapping[:n] = self.mapping[:n]
+
+        self.mapping = new_mapping
+
+    def set(self, input_ids: torch.LongTensor, tensor: Optional[torch.Tensor] = None):
+        if len(input_ids.shape) != 0:
+            if tensor is not None:
+                return [self.set(id, t) for id, t in zip(input_ids, tensor)]
+            else:
+                return [self.set(id) for id in input_ids]
+
+        id = self.mapping[input_ids]
+
+        if id == -1:
+            id = len(self.params)
+            self.mapping[input_ids] = id
+            self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype))
+
+        self.params[id] = tensor if tensor is not None else torch.zeros(
+            self.embedding_dim, device=self.mapping.device, dtype=self.dtype)
+
+    def unset(self, input_ids: torch.LongTensor):
+        self.mapping[input_ids] = -1
--
cgit v1.2.3-54-g00ecf
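
A minimal usage sketch (not part of the commit) may help clarify the design: `mapping` is a dense long tensor with one slot per token id, holding either -1 ("no trainable embedding for this id") or an index into `params`, so only the handful of ids that were registered via `set()` carry trainable parameters. Note the class appends plain tensors to an `nn.ParameterList`, which recent PyTorch versions (1.12+) wrap into `Parameter`s automatically. The embedding size of 768 and the token ids below are made-up example values.

    import torch

    from models.sparse import PseudoSparseEmbedding  # path from the patch above

    emb = PseudoSparseEmbedding(embedding_dim=768)   # 768 is an arbitrary example size

    # Grow the mapping to cover ten token ids; every new slot starts at -1,
    # meaning "no trainable embedding for this id".
    emb.resize(10)

    # Register trainable vectors for two ids. Called without a tensor, set()
    # initializes the slot to zeros of shape (embedding_dim,).
    emb.set(torch.tensor(3))
    emb.set(torch.tensor(7), torch.randn(768))

    # Look up a batch of ids: embs stacks the vectors of the ids that have an
    # entry, and mask marks which input positions were found.
    embs, mask = emb(torch.tensor([1, 3, 7]))
    print(mask)        # tensor([False,  True,  True])
    print(embs.shape)  # torch.Size([2, 768])

    # Deactivate an id again; its parameter stays in the list but is no
    # longer reachable through the mapping.
    emb.unset(torch.tensor(3))

Because `forward` returns both the stacked vectors and a boolean mask over the input, a caller can, for instance, overlay the sparse embeddings onto the output of a regular dense embedding at exactly the masked positions, which is presumably why the mask is exposed rather than resolved internally.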