from typing import Optional

import torch
import torch.nn as nn


class PseudoSparseEmbedding(nn.Module):
    """Stores embeddings only for token ids that were explicitly `set`.

    `mapping` translates a vocabulary id into an index into `params`; ids with
    no stored embedding map to -1, and `forward` reports them via a mask so the
    caller can fall back to its regular embedding table.
    """

    def __init__(self, embedding_dim: int, dropout_p: float = 0.0, device=None, dtype=torch.float32):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.dtype = dtype
        self.params = nn.ParameterList()
        # nn.Identity keeps the module picklable, unlike a bare lambda.
        self.dropout = nn.Dropout(p=dropout_p) if dropout_p > 0.0 else nn.Identity()
        self.register_buffer("mapping", torch.zeros(0, device=device, dtype=torch.long))

    def forward(self, input_ids: torch.LongTensor):
        input_ids = input_ids.to(self.mapping.device)
        ids = self.mapping[input_ids]
        # True where the id has a stored embedding, False where it is unmapped.
        mask = ids != -1
        if torch.all(~mask):
            embs = None
        else:
            embs = self.dropout(torch.stack([self.params[id] for id in ids[mask]]))
        return embs, mask

    def resize(self, new_num_embeddings: int):
        # Preserve existing assignments; newly added positions start unmapped (-1).
        old_num_embeddings = self.mapping.shape[0]
        n = min(old_num_embeddings, new_num_embeddings)
        new_mapping = torch.zeros(new_num_embeddings, device=self.mapping.device, dtype=torch.long) - 1
        new_mapping[:n] = self.mapping[:n]
        self.mapping = new_mapping

    def set(self, input_ids: torch.LongTensor, tensor: Optional[torch.Tensor] = None):
        # Non-scalar input: recurse element-wise, pairing each id with its tensor.
        if len(input_ids.shape) != 0:
            if tensor is not None:
                return [self.set(id, t) for id, t in zip(input_ids, tensor)]
            else:
                return [self.set(id) for id in input_ids]

        if tensor is None:
            tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)

        if tensor.shape[-1] != self.embedding_dim:
            raise ValueError(
                f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]"
            )

        id = self.mapping[input_ids]
        if id == -1:
            # Allocate a new slot: reserve the next index, then append a placeholder.
            id = len(self.params)
            self.mapping[input_ids] = id
            self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype))

        self.params[id] = tensor

    def unset(self, input_ids: torch.LongTensor):
        # Detach the ids from their slots; the slots themselves are not reclaimed.
        self.mapping[input_ids] = -1
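

# The demo below is an illustrative sketch, not part of the original module:
# it assumes an embedding size of 4 and a vocabulary of 10 ids purely for
# demonstration, and shows how `resize`, `set`, `forward`, and `unset` interact.
if __name__ == "__main__":
    emb = PseudoSparseEmbedding(embedding_dim=4)
    emb.resize(10)  # all 10 ids start unmapped (-1)

    # Store vectors for ids 2 and 7; id 5 gets the default zero placeholder.
    emb.set(torch.tensor(2), torch.randn(4))
    emb.set(torch.tensor(7), torch.randn(4))
    emb.set(torch.tensor(5))

    # Look up a batch: `mask` marks which positions have stored embeddings.
    embs, mask = emb(torch.tensor([1, 2, 5, 7]))
    print(mask)        # tensor([False,  True,  True,  True])
    print(embs.shape)  # torch.Size([3, 4])

    # Detach id 5 again; its slot in `params` stays allocated but unreferenced.
    emb.unset(torch.tensor(5))
    _, mask = emb(torch.tensor([5]))
    print(mask)        # tensor([False])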