summaryrefslogtreecommitdiffstats
path: root/models/sparse.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/sparse.py')
-rw-r--r--models/sparse.py13
1 files changed, 10 insertions, 3 deletions
diff --git a/models/sparse.py b/models/sparse.py
index d706db5..bcb2897 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -5,22 +5,29 @@ import torch.nn as nn
5 5
6 6
7class PseudoSparseEmbedding(nn.Module): 7class PseudoSparseEmbedding(nn.Module):
8 def __init__(self, embedding_dim: int, device=None, dtype=torch.float32): 8 def __init__(self, embedding_dim: int, dropout_p: float = 0.0, device=None, dtype=torch.float32):
9 super().__init__() 9 super().__init__()
10 10
11 self.embedding_dim = embedding_dim 11 self.embedding_dim = embedding_dim
12 self.dtype = dtype 12 self.dtype = dtype
13 self.params = nn.ParameterList() 13 self.params = nn.ParameterList()
14
15 if dropout_p > 0.0:
16 self.dropout = nn.Dropout(p=dropout_p)
17 else:
18 self.dropout = lambda x: x
19
14 self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long)) 20 self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long))
15 21
16 def forward(self, input_ids: torch.LongTensor): 22 def forward(self, input_ids: torch.LongTensor):
17 ids = self.mapping[input_ids.to(self.mapping.device)] 23 input_ids = input_ids.to(self.mapping.device)
24 ids = self.mapping[input_ids]
18 mask = ~(ids == -1) 25 mask = ~(ids == -1)
19 26
20 if torch.all(~mask): 27 if torch.all(~mask):
21 embs = None 28 embs = None
22 else: 29 else:
23 embs = torch.stack([self.params[id] for id in ids[mask]]) 30 embs = self.dropout(torch.stack([self.params[id] for id in ids[mask]]))
24 31
25 return embs, mask 32 return embs, mask
26 33