path: root/models/sparse.py
from typing import Optional

import torch
import torch.nn as nn


class SparseEmbedding(nn.Embedding):
    """Embedding layer with a frozen base weight and per-row trainable deltas.

    Rows registered via ``mark_trainable`` receive a trainable delta vector that is
    scaled by ``alpha`` and (optionally) passed through dropout before being added to
    the frozen lookup result. ``persist`` merges the deltas back into the base weight.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        alpha: int = 1,
        dropout: float = 0.0,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)

        # Maps vocabulary index -> position in ``self.trainable``; -1 means no delta.
        self.register_buffer(
            "trainable_ids", self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1
        )

        self.trainable = nn.ParameterList()
        self.scaling = alpha
        self.dropout_p = dropout
        # The base embedding matrix is frozen; only the per-row deltas are trained.
        self.weight.requires_grad = False

        if dropout > 0.0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = nn.Identity()

        self.reset_parameters()

    def new_resized(
        self, new_num_embeddings: int, initializer_factor: Optional[float] = None
    ):
        """Return a copy with a resized vocabulary, carrying over weights and deltas."""
        n = min(self.num_embeddings, new_num_embeddings)

        new_emb = SparseEmbedding(
            new_num_embeddings,
            self.embedding_dim,
            self.scaling,
            self.dropout_p,
            device=self.weight.device,
            dtype=self.weight.dtype,
        )
        # Initialize any newly added rows, then copy over the overlapping rows.
        if initializer_factor is not None:
            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
        else:
            nn.init.zeros_(new_emb.weight.data)
        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
        for param in self.trainable:
            new_emb.trainable.append(param)
        new_emb.trainable_ids[:n] = self.trainable_ids[:n]

        return new_emb

    def mark_trainable(self, input_ids: torch.LongTensor):
        """Allocate a trainable delta vector for every id that does not have one yet."""
        trainable_ids = self.trainable_ids[input_ids]
        # Deduplicate so repeated ids in the batch get exactly one delta each.
        new_ids = torch.unique(input_ids[trainable_ids == -1])

        if new_ids.shape[0] == 0:
            return

        n1 = len(self.trainable)
        n2 = n1 + new_ids.shape[0]
        self.trainable_ids[new_ids] = torch.arange(
            n1, n2, device=self.trainable_ids.device
        )
        for _ in new_ids:
            # Deltas start at zero and are wrapped as parameters so they get optimized.
            self.trainable.append(nn.Parameter(self.weight.new_zeros(self.embedding_dim)))

    def get_weights(self, input_ids: torch.Tensor):
        """Return the scaled trainable deltas for ``input_ids`` (zeros where none exist)."""
        original_shape = input_ids.shape

        if input_ids.dim() != 1:
            input_ids = input_ids.reshape(-1)

        weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim))

        trainable_ids = self.trainable_ids[input_ids]
        mask = trainable_ids != -1
        elems = [self.trainable[i] for i in trainable_ids[mask].tolist()]

        if len(elems) != 0:
            w = self.dropout(torch.stack(elems)) * self.scaling
            weights[mask] = w.to(dtype=weights.dtype)

        if len(original_shape) != 1:
            weights = weights.view(*original_shape, -1)

        return weights

    def persist(self):
        """Merge all trainable deltas into the frozen base weight and clear them."""
        with torch.no_grad():
            all_ids = torch.arange(self.trainable_ids.shape[0], device=self.weight.device)
            self.weight.data += self.get_weights(all_ids)
        self.trainable_ids[:] = -1
        self.trainable = nn.ParameterList()

    def reset_parameters(self):
        nn.Embedding.reset_parameters(self)
        # The base class calls this from its __init__, before the sparse state exists.
        if hasattr(self, "trainable"):
            self.trainable_ids[:] = -1
            self.trainable = nn.ParameterList()

    def train(self, mode: bool = True):
        nn.Embedding.train(self, mode)
        self.trainable.train(mode)
        return self

    def eval(self):
        nn.Embedding.eval(self)
        self.trainable.eval()
        return self

    def forward(self, input_ids: torch.LongTensor):
        # Frozen base lookup plus the (scaled, dropped-out) sparse trainable deltas.
        result = nn.Embedding.forward(self, input_ids)
        result = result + self.get_weights(input_ids)
        return result
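

# A minimal usage sketch (illustrative assumption, not part of the original module):
# only ids registered via ``mark_trainable`` receive trainable deltas, while the
# frozen base weight is left untouched until ``persist`` merges the deltas back in.
if __name__ == "__main__":
    emb = SparseEmbedding(num_embeddings=100, embedding_dim=16, alpha=1, dropout=0.1)

    # Register the ids that should get trainable deltas (e.g. newly added tokens).
    emb.mark_trainable(torch.tensor([97, 98, 99]))

    # Only the sparse deltas are optimized; the base embedding matrix stays frozen.
    optimizer = torch.optim.SGD(emb.trainable.parameters(), lr=0.1)

    batch = torch.tensor([[1, 2, 98], [99, 4, 97]])
    out = emb(batch)  # (2, 3, 16): frozen lookup + scaled trainable deltas
    out.sum().backward()
    optimizer.step()

    # Merge the learned deltas into the base weight; eval() disables dropout first.
    emb.eval()
    emb.persist()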