summaryrefslogtreecommitdiffstats
path: root/models/sparse.py
blob: 55c983752e16c847d0022218e8766b06ec984168 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from typing import Optional

import torch
import torch.nn as nn


class SparseEmbedding(nn.Embedding):
    """Embedding with a frozen base table plus sparse, trainable per-row deltas.

    The base ``weight`` comes from ``nn.Embedding`` and is frozen
    (``requires_grad = False``). Rows registered via :meth:`mark_trainable`
    get a zero-initialized delta vector stored in ``self.trainable``.
    ``trainable_ids`` maps each row index to its slot in that list; -1 means
    the row has no delta. A lookup returns
    ``weight[id] + scaling * dropout(delta[id])``.

    Args:
        num_embeddings: Number of rows in the table.
        embedding_dim: Size of each embedding vector.
        alpha: Multiplier applied to every delta (stored as ``scaling``).
        dropout: Dropout probability applied to deltas. A value of 0.0
            disables dropout (``nn.Identity`` is used instead).
        **kwargs: Forwarded unchanged to ``nn.Embedding`` (for example
            ``padding_idx``, ``device``, ``dtype`` or ``_weight``).
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        alpha: float = 1,
        dropout: float = 0.0,
        **kwargs
    ):
        # nn.Embedding.__init__ initializes ``weight``, or adopts a
        # caller-supplied ``_weight``. It calls ``reset_parameters`` before
        # ``trainable`` exists, so at that point only the base table is touched.
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)

        # Row index -> slot in ``self.trainable``; -1 means "no delta yet".
        self.register_buffer(
            "trainable_ids", self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1
        )

        self.trainable = nn.ParameterList()
        self.scaling = alpha
        self.dropout_p = dropout
        self.weight.requires_grad = False

        if dropout > 0.0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = nn.Identity()
        # NOTE: reset_parameters() is deliberately not called again here.
        # Doing so would re-randomize ``weight`` and silently discard a
        # ``_weight`` passed through kwargs.

    def new_resized(
        self, new_num_embeddings: int, initializer_factor: Optional[float] = None
    ):
        """Return a new SparseEmbedding with ``new_num_embeddings`` rows.

        The first ``min(old, new)`` rows of ``weight`` and ``trainable_ids``
        are copied over. Rows beyond that are drawn from
        ``N(0, (0.02 * initializer_factor)^2)`` when ``initializer_factor`` is
        given, and are zero otherwise. The delta Parameters are shared with
        ``self``, not copied. The training mode of ``self`` is carried over.
        """
        n = min(self.num_embeddings, new_num_embeddings)

        new_emb = SparseEmbedding(
            new_num_embeddings,
            self.embedding_dim,
            self.scaling,
            self.dropout_p,
            device=self.weight.device,
            dtype=self.weight.dtype,
        )
        if initializer_factor is not None:
            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
        else:
            nn.init.zeros_(new_emb.weight.data)
        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
        # Same Parameter objects, so an optimizer holding them keeps working.
        for param in self.trainable:
            new_emb.trainable.append(param)
        new_emb.trainable_ids[:n] = self.trainable_ids[:n]
        new_emb.train(self.training)

        return new_emb

    def mark_trainable(self, input_ids: torch.LongTensor):
        """Allocate a zero delta for every id in ``input_ids`` that lacks one.

        ``input_ids`` may have any shape and may contain repeated ids. Each
        distinct untracked id receives exactly one new slot, assigned in
        first-occurrence order. Ids that already have a slot are left alone.
        """
        flat_ids = input_ids.reshape(-1)
        untracked = flat_ids[self.trainable_ids[flat_ids] == -1]

        if untracked.numel() == 0:
            return

        # Deduplicate (keeping first-occurrence order) so repeated ids don't
        # allocate orphaned slots or write trainable_ids via duplicate indices.
        new_ids = list(dict.fromkeys(untracked.tolist()))

        device = self.trainable_ids.device
        start = len(self.trainable)
        self.trainable_ids[torch.tensor(new_ids, dtype=torch.long, device=device)] = (
            torch.arange(start, start + len(new_ids), device=device)
        )
        for _ in new_ids:
            self.trainable.append(nn.Parameter(self.weight.new_zeros(self.embedding_dim)))

    def get_weights(self, input_ids: torch.Tensor, apply_dropout: bool = True):
        """Return the scaled deltas for ``input_ids``.

        The result has shape ``(*input_ids.shape, embedding_dim)``. Ids
        without a delta contribute zeros.

        Args:
            input_ids: Integer tensor of row indices, of any shape.
            apply_dropout: When False, bypass ``self.dropout``. This gives
                the deterministic delta that :meth:`persist` folds in.
        """
        flat_ids = input_ids.reshape(-1)
        weights = self.weight.new_zeros((flat_ids.shape[0], self.embedding_dim))

        slots = self.trainable_ids[flat_ids]
        mask = slots != -1
        deltas = [self.trainable[slot] for slot in slots[mask].tolist()]

        if deltas:
            stacked = torch.stack(deltas)
            if apply_dropout:
                stacked = self.dropout(stacked)
            weights[mask] = (stacked * self.scaling).to(dtype=weights.dtype)

        return weights.view(*input_ids.shape, self.embedding_dim)

    def persist(self, clear: bool = False):
        """Fold every scaled delta into ``weight``, without dropout.

        Args:
            clear: If True, forget all deltas (reset ``trainable_ids`` to -1
                and replace ``trainable`` with an empty list). Otherwise keep
                the slots and zero the delta Parameters in place.
        """
        # no_grad: the in-place updates below touch leaf Parameters that
        # require grad, which autograd forbids outside no_grad.
        with torch.no_grad():
            all_ids = torch.arange(
                self.trainable_ids.shape[0], device=self.trainable_ids.device
            )
            self.weight.add_(self.get_weights(all_ids, apply_dropout=False))

            if clear:
                self.trainable_ids.fill_(-1)
                self.trainable = nn.ParameterList()
            else:
                for param in self.trainable:
                    param.zero_()

    def reset_parameters(self):
        """Re-initialize ``weight`` via nn.Embedding and drop all deltas."""
        nn.Embedding.reset_parameters(self)
        # nn.Embedding.__init__ calls this before ``trainable`` exists.
        if hasattr(self, "trainable"):
            self.trainable_ids.fill_(-1)
            self.trainable = nn.ParameterList()

    def train(self, mode: bool = True):
        """Set training mode and return ``self``, matching ``nn.Module.train``.

        ``nn.Module.train`` already recurses into child modules, which include
        ``self.trainable`` and ``self.dropout``.
        """
        return nn.Embedding.train(self, mode)

    def eval(self):
        """Switch to evaluation mode and return ``self``."""
        return self.train(False)

    def forward(self, input_ids: torch.LongTensor):
        """Look up base embeddings and add the (dropout-applied) deltas."""
        base = nn.Embedding.forward(self, input_ids)
        return base + self.get_weights(input_ids)