diff options
Diffstat (limited to 'models/sparse.py')
-rw-r--r-- | models/sparse.py | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/models/sparse.py b/models/sparse.py index 0b15454..8910316 100644 --- a/models/sparse.py +++ b/models/sparse.py | |||
@@ -13,10 +13,7 @@ class PseudoSparseEmbedding(nn.Module): | |||
13 | self.params = nn.ParameterList() | 13 | self.params = nn.ParameterList() |
14 | self.mapping = torch.zeros(0, device=device, dtype=torch.long) | 14 | self.mapping = torch.zeros(0, device=device, dtype=torch.long) |
15 | 15 | ||
16 | def forward(self, input_ids: Optional[torch.LongTensor] = None): | 16 | def forward(self, input_ids: torch.LongTensor): |
17 | if input_ids is None: | ||
18 | input_ids = torch.arange(self.mapping.shape[0]) | ||
19 | |||
20 | ids = self.mapping[input_ids.to(self.mapping.device)] | 17 | ids = self.mapping[input_ids.to(self.mapping.device)] |
21 | mask = ~(ids == -1) | 18 | mask = ~(ids == -1) |
22 | 19 | ||
@@ -43,6 +40,12 @@ class PseudoSparseEmbedding(nn.Module): | |||
43 | else: | 40 | else: |
44 | return [self.set(id) for id in input_ids] | 41 | return [self.set(id) for id in input_ids] |
45 | 42 | ||
43 | if tensor is None: | ||
44 | tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype) | ||
45 | |||
46 | if tensor.shape[-1] != self.embedding_dim: | ||
47 | raise ValueError(f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]") | ||
48 | |||
46 | id = self.mapping[input_ids] | 49 | id = self.mapping[input_ids] |
47 | 50 | ||
48 | if id == -1: | 51 | if id == -1: |
@@ -50,8 +53,7 @@ class PseudoSparseEmbedding(nn.Module): | |||
50 | self.mapping[input_ids] = id | 53 | self.mapping[input_ids] = id |
51 | self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)) | 54 | self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)) |
52 | 55 | ||
53 | self.params[id] = tensor if tensor is not None else torch.zeros( | 56 | self.params[id] = tensor |
54 | self.embedding_dim, device=self.mapping.device, dtype=self.dtype) | ||
55 | 57 | ||
56 | def unset(self, input_ids: torch.LongTensor): | 58 | def unset(self, input_ids: torch.LongTensor): |
57 | self.mapping[input_ids] = -1 | 59 | self.mapping[input_ids] = -1 |