diff options
Diffstat (limited to 'models/sparse.py')
| -rw-r--r-- | models/sparse.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/models/sparse.py b/models/sparse.py index 8910316..d706db5 100644 --- a/models/sparse.py +++ b/models/sparse.py | |||
| @@ -11,7 +11,7 @@ class PseudoSparseEmbedding(nn.Module): | |||
| 11 | self.embedding_dim = embedding_dim | 11 | self.embedding_dim = embedding_dim |
| 12 | self.dtype = dtype | 12 | self.dtype = dtype |
| 13 | self.params = nn.ParameterList() | 13 | self.params = nn.ParameterList() |
| 14 | self.mapping = torch.zeros(0, device=device, dtype=torch.long) | 14 | self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long)) |
| 15 | 15 | ||
| 16 | def forward(self, input_ids: torch.LongTensor): | 16 | def forward(self, input_ids: torch.LongTensor): |
| 17 | ids = self.mapping[input_ids.to(self.mapping.device)] | 17 | ids = self.mapping[input_ids.to(self.mapping.device)] |
