from typing import Optional

import torch
import torch.nn as nn


class PseudoSparseEmbedding(nn.Module):
    """Sparse-like embedding that only allocates parameters for token ids that
    have been explicitly registered via `set`. `mapping[token_id]` holds the
    index into `params`, or -1 if the token has no embedding."""

    def __init__(self, embedding_dim: int, device=None, dtype=torch.float32):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.dtype = dtype
        self.params = nn.ParameterList()
        # Registered as a buffer so it moves with .to()/.cuda() and is saved in
        # the state dict. Starts empty; call `resize` before first use.
        self.register_buffer("mapping", torch.zeros(0, device=device, dtype=torch.long))

    def forward(self, input_ids: torch.LongTensor):
        ids = self.mapping[input_ids.to(self.mapping.device)]
        mask = ids != -1
        if not mask.any():
            # No input id has a registered embedding.
            embs = None
        else:
            embs = torch.stack([self.params[i] for i in ids[mask].tolist()])
        return embs, mask

    def resize(self, new_num_embeddings: int):
        old_num_embeddings = self.mapping.shape[0]
        n = min(old_num_embeddings, new_num_embeddings)
        # New slots start at -1 (unmapped); existing mappings are preserved.
        new_mapping = torch.full((new_num_embeddings,), -1, device=self.mapping.device, dtype=torch.long)
        new_mapping[:n] = self.mapping[:n]
        self.mapping = new_mapping

    def set(self, input_ids: torch.LongTensor, tensor: Optional[torch.Tensor] = None):
        if input_ids.dim() != 0:
            # Batched call: recurse over each (id, embedding) pair.
            if tensor is not None:
                return [self.set(input_id, t) for input_id, t in zip(input_ids, tensor)]
            return [self.set(input_id) for input_id in input_ids]

        if tensor is None:
            tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)

        if tensor.shape[-1] != self.embedding_dim:
            raise ValueError(
                f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]"
            )

        # Wrap the value in nn.Parameter so the embedding is trainable
        # (appending a raw tensor to a ParameterList would not register it
        # as a parameter).
        param = nn.Parameter(tensor.detach().to(device=self.mapping.device, dtype=self.dtype))
        idx = int(self.mapping[input_ids])
        if idx == -1:
            # Allocate a new parameter slot for this token id.
            idx = len(self.params)
            self.mapping[input_ids] = idx
            self.params.append(param)
        else:
            self.params[idx] = param

    def unset(self, input_ids: torch.LongTensor):
        # Marks the id(s) as unmapped; the underlying parameter slot is kept
        # in `params` but becomes unreachable through `mapping`.
        self.mapping[input_ids] = -1
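

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): exercises
    # set/forward/resize/unset on a toy vocabulary. The vocabulary size,
    # token ids, and embedding values below are illustrative assumptions.
    emb = PseudoSparseEmbedding(embedding_dim=4)
    emb.resize(10)  # reserve mapping slots for token ids 0..9

    # Register embeddings for two token ids; id 7 gets an explicit vector,
    # id 3 defaults to zeros.
    emb.set(torch.tensor(3))
    emb.set(torch.tensor(7), torch.arange(4, dtype=torch.float32))

    # Only mapped ids produce embeddings; the mask marks which inputs matched.
    embs, mask = emb(torch.tensor([3, 5, 7]))
    print(mask)        # tensor([ True, False,  True])
    print(embs.shape)  # torch.Size([2, 4])

    # Unmapping an id hides it from future lookups (its slot is not reclaimed).
    emb.unset(torch.tensor(3))
    embs, mask = emb(torch.tensor([3, 7]))
    print(mask)        # tensor([False,  True])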