Diffstat (limited to 'models/sparse.py')
-rw-r--r--  models/sparse.py  110
1 file changed, 110 insertions(+), 0 deletions(-)
diff --git a/models/sparse.py b/models/sparse.py
new file mode 100644
index 0000000..bd45696
--- /dev/null
+++ b/models/sparse.py
@@ -0,0 +1,110 @@
from typing import Optional

import torch
import torch.nn as nn


class SparseEmbedding(nn.Embedding):
    """Embedding whose base weight is frozen; only per-token delta vectors,
    allocated on demand via ``mark_trainable``, receive gradients."""
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        alpha: int = 1,
        dropout: float = 0.0,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)

        # Maps token id -> index into self.trainable; -1 marks "not trainable".
        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1)

        self.trainable = nn.ParameterList()
        self.scaling = alpha
        self.dropout_p = dropout
        self.weight.requires_grad = False  # the base embedding table stays frozen

        if dropout > 0.0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = nn.Identity()

        self.reset_parameters()

    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
        """Return a copy with a resized vocabulary, carrying over the
        overlapping rows of the frozen weight and the trainable-id mapping."""
        n = min(self.num_embeddings, new_num_embeddings)

        new_emb = SparseEmbedding(
            new_num_embeddings,
            self.embedding_dim,
            self.scaling,
            self.dropout_p,
            device=self.weight.device,
            dtype=self.weight.dtype
        )
        if initializer_factor is not None:
            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
        else:
            nn.init.zeros_(new_emb.weight.data)
        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
        for param in self.trainable:
            new_emb.trainable.append(param)
        new_emb.trainable_ids[:n] = self.trainable_ids[:n]

        return new_emb

    def mark_trainable(self, input_ids: torch.LongTensor):
        """Allocate a zero-initialized delta vector for every id in
        ``input_ids`` that is not yet marked trainable."""
        trainable_ids = self.trainable_ids[input_ids]
        # Deduplicate so a repeated id gets exactly one slot.
        new_ids = input_ids[trainable_ids == -1].unique()

        if new_ids.shape[0] == 0:
            return

        n1 = len(self.trainable)
        n2 = n1 + new_ids.shape[0]
        self.trainable_ids[new_ids] = torch.arange(n1, n2, device=self.trainable_ids.device)
        for _ in new_ids:
            self.trainable.append(nn.Parameter(self.weight.new_zeros(self.embedding_dim)))

    def get_weights(self, input_ids: torch.Tensor):
        """Gather the scaled delta vectors for ``input_ids``; ids without a
        trainable entry contribute zeros."""
        original_shape = input_ids.shape

        if input_ids.dim() != 1:
            input_ids = input_ids.reshape(-1)

        weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim))

        trainable_ids = self.trainable_ids[input_ids]
        mask = trainable_ids != -1
        elems = [self.trainable[i] for i in trainable_ids[mask].tolist()]

        if len(elems) != 0:
            w = self.dropout(torch.stack(elems)) * self.scaling
            weights[mask] = w.to(dtype=weights.dtype)

        if len(original_shape) != 1:
            weights = weights.view(*original_shape, -1)

        return weights

    @torch.no_grad()
    def persist(self):
        """Fold the learned deltas into the frozen weight, then clear them.
        Call in eval mode so dropout is not applied to the deltas."""
        all_ids = torch.arange(self.trainable_ids.shape[0], device=self.trainable_ids.device)
        self.weight.data += self.get_weights(all_ids)
        self.trainable_ids[:] = -1
        self.trainable = nn.ParameterList()

    def reset_parameters(self):
        nn.Embedding.reset_parameters(self)
        # nn.Embedding.__init__ calls reset_parameters before the sparse
        # state exists, hence the guard.
        if hasattr(self, "trainable"):
            self.trainable_ids[:] = -1
            self.trainable = nn.ParameterList()

    def train(self, mode: bool = True):
        nn.Embedding.train(self, mode)
        self.trainable.train(mode)
        return self

    def eval(self):
        return self.train(False)

    def forward(self, input_ids: torch.LongTensor):
        # Frozen lookup plus the sparse trainable deltas.
        result = nn.Embedding.forward(self, input_ids)
        return result + self.get_weights(input_ids)
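
A minimal usage sketch (illustrative, not part of this commit): mark a few ids as trainable, check that gradients reach only the deltas, then fold them back into the frozen table. The vocabulary size, ids, and loss below are hypothetical.

# Hypothetical example; sizes and ids are made up for illustration.
import torch
from models.sparse import SparseEmbedding

emb = SparseEmbedding(num_embeddings=100, embedding_dim=16)
ids = torch.tensor([[3, 7, 7], [42, 3, 0]])

emb.mark_trainable(ids.view(-1))  # allocate deltas for ids 0, 3, 7, 42
out = emb(ids)                    # frozen lookup + trainable deltas
out.sum().backward()              # gradients flow only into the deltas

assert emb.weight.grad is None    # base table is frozen
assert all(p.grad is not None for p in emb.trainable)

emb.eval()
emb.persist()                     # fold deltas into the base weight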