Diffstat (limited to 'models/sparse.py')
-rw-r--r--  models/sparse.py  66
1 file changed, 0 insertions, 66 deletions
diff --git a/models/sparse.py b/models/sparse.py
deleted file mode 100644
index 07b3413..0000000
--- a/models/sparse.py
+++ /dev/null
@@ -1,66 +0,0 @@
-from typing import Optional
-
-import torch
-import torch.nn as nn
-
-
-class PseudoSparseEmbedding(nn.Module):
-    def __init__(self, embedding_dim: int, dropout_p: float = 0.0, device=None, dtype=torch.float32):
-        super().__init__()
-
-        self.embedding_dim = embedding_dim
-        self.dtype = dtype
-        self.params = nn.ParameterList()
-
-        if dropout_p > 0.0:
-            self.dropout = nn.Dropout(p=dropout_p)
-        else:
-            self.dropout = nn.Identity()
-
-        self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long))
-
-    def forward(self, input_ids: torch.LongTensor):
-        input_ids = input_ids.to(self.mapping.device)
-        ids = self.mapping[input_ids]
-        mask = ~(ids == -1)
-
-        if torch.all(~mask):
-            embs = None
-        else:
-            embs = self.dropout(torch.stack([self.params[id] for id in ids[mask]]))
-
-        return embs, mask
-
-    def resize(self, new_num_embeddings: int):
-        old_num_embeddings = self.mapping.shape[0]
-        n = min(old_num_embeddings, new_num_embeddings)
-
-        new_mapping = torch.zeros(new_num_embeddings, device=self.mapping.device, dtype=torch.long) - 1
-        new_mapping[:n] = self.mapping[:n]
-
-        self.mapping = new_mapping
-
-    def set(self, input_ids: torch.LongTensor, tensor: Optional[torch.Tensor] = None):
-        if len(input_ids.shape) != 0:
-            if tensor is not None:
-                return [self.set(id, t) for id, t in zip(input_ids, tensor)]
-            else:
-                return [self.set(id) for id in input_ids]
-
-        if tensor is None:
-            tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)
-
-        if tensor.shape[-1] != self.embedding_dim:
-            raise ValueError(f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]")
-
-        id = self.mapping[input_ids]
-
-        if id == -1:
-            id = len(self.params)
-            self.mapping[input_ids] = id
-            self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype))
-
-        self.params[id] = tensor
-
-    def unset(self, input_ids: torch.LongTensor):
-        self.mapping[input_ids] = -1
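
For context, a minimal sketch of how the removed PseudoSparseEmbedding could be exercised. This is not part of the commit: it assumes a checkout from before this deletion (so the class is still importable from models.sparse) and a recent PyTorch (>= 1.12), where nn.ParameterList accepts plain tensors and wraps them into Parameters.

import torch

from models.sparse import PseudoSparseEmbedding  # pre-deletion path

emb = PseudoSparseEmbedding(embedding_dim=768)

# Grow the id -> slot mapping to cover 10 token ids; resize() fills
# every new slot with the -1 sentinel, meaning "no override stored".
emb.resize(10)

# Store an override vector for token id 3. Passing no tensor would
# store a zero vector of length embedding_dim instead.
emb.set(torch.tensor(3), torch.randn(768))

# Look up a batch of ids: embs stacks the stored vectors for the hits,
# mask marks which input positions had an override.
embs, mask = emb(torch.tensor([1, 3, 3]))
print(mask)        # tensor([False,  True,  True])
print(embs.shape)  # torch.Size([2, 768])

# Drop the override again; only the mapping entry is cleared, the
# slot in params is orphaned rather than freed.
emb.unset(torch.tensor(3))

The -1 sentinel in the mapping buffer is what makes the embedding "pseudo-sparse": forward returns only the stacked vectors for ids that actually have an entry, plus a mask telling the caller which positions those were.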