# models/sparse.py
from typing import Optional

import torch
import torch.nn as nn


class PseudoSparseEmbedding(nn.Module):
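    """Embedding table that stores rows only for ids explicitly assigned via `set`.

    `mapping` translates vocabulary ids to slots in `params`; -1 marks unset ids.
    `forward` returns the stored rows plus a boolean mask of which positions hit.
    """
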
    def __init__(self, embedding_dim: int, device=None, dtype=torch.float32):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.dtype = dtype
        self.params = nn.ParameterList()
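        # Maps vocabulary id -> slot in self.params; -1 means "unset".
        # Starts empty; call resize() to grow it to the vocabulary size.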
        self.mapping = torch.zeros(0, device=device, dtype=torch.long)

    def forward(self, input_ids: torch.LongTensor):
        ids = self.mapping[input_ids.to(self.mapping.device)]
        mask = ids != -1

        if not mask.any():
            # None of the requested ids has a stored embedding.
            embs = None
        else:
            # Gather only the rows that have been set.
            embs = torch.stack([self.params[id] for id in ids[mask]])

        return embs, mask
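
    # Example with hypothetical contents: if mapping = [-1, 0, 1], then
    # forward(torch.tensor([0, 1, 2])) yields mask = tensor([False, True, True])
    # and embs of shape [2, embedding_dim] (the stored rows for ids 1 and 2).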

    def resize(self, new_num_embeddings: int):
        old_num_embeddings = self.mapping.shape[0]
        n = min(old_num_embeddings, new_num_embeddings)

        new_mapping = torch.full((new_num_embeddings,), -1, device=self.mapping.device, dtype=torch.long)
        new_mapping[:n] = self.mapping[:n]

        self.mapping = new_mapping

    def set(self, input_ids: torch.LongTensor, tensor: Optional[torch.Tensor] = None):
        if input_ids.dim() != 0:
            # Batched call: recurse on each scalar id (paired with its row of `tensor`).
            if tensor is not None:
                return [self.set(id, t) for id, t in zip(input_ids, tensor)]
            else:
                return [self.set(id) for id in input_ids]

        if tensor is None:
            tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)

        if tensor.shape[-1] != self.embedding_dim:
            raise ValueError(f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]")

        id = self.mapping[input_ids]

        if id == -1:
            # First assignment for this id: allocate a new slot at the end of the list.
            id = len(self.params)
            self.mapping[input_ids] = id
            self.params.append(nn.Parameter(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)))

        # Wrap the value in nn.Parameter so the row is registered as a module
        # parameter; assigning a plain tensor into nn.ParameterList would leave
        # it out of parameters() and the state dict.
        self.params[id] = nn.Parameter(tensor)

    def unset(self, input_ids: torch.LongTensor):
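        # Keep the allocated slots in self.params; just sever the id -> slot link.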
        self.mapping[input_ids] = -1
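

# A minimal sketch, assuming the module is consumed alongside a dense embedding
# layer (an assumption about usage, not part of the original file): positions
# covered by `mask` take the sparse rows, all other positions keep the dense
# lookup. `merge_with_dense` is a hypothetical helper name.
def merge_with_dense(dense_out: torch.Tensor, embs: Optional[torch.Tensor], mask: torch.Tensor) -> torch.Tensor:
    if embs is None:
        return dense_out
    out = dense_out.clone()
    out[mask] = embs.to(dense_out.dtype)  # scatter sparse rows into masked positions
    return out


# Illustrative usage with made-up sizes (vocabulary of 10, embedding_dim of 4).
if __name__ == "__main__":
    emb = PseudoSparseEmbedding(embedding_dim=4)
    emb.resize(10)  # ids 0..9, all initially unset (-1)

    emb.set(torch.tensor([2, 5]), torch.randn(2, 4))
    embs, mask = emb(torch.tensor([1, 2, 5]))
    print(mask)        # tensor([False,  True,  True])
    print(embs.shape)  # torch.Size([2, 4])

    dense = nn.Embedding(10, 4)
    merged = merge_with_dense(dense(torch.tensor([1, 2, 5])), embs, mask)
    print(merged.shape)  # torch.Size([3, 4])

    emb.unset(torch.tensor(2))
    embs, mask = emb(torch.tensor([2]))
    print(embs)  # None: no requested id has a stored row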