From 30b557c8e1f03b4748ac3efca599ff51d66561cb Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 4 Apr 2023 07:30:43 +0200 Subject: TI: Bring back old embedding decay --- models/sparse.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) (limited to 'models/sparse.py') diff --git a/models/sparse.py b/models/sparse.py index 0b15454..8910316 100644 --- a/models/sparse.py +++ b/models/sparse.py @@ -13,10 +13,7 @@ class PseudoSparseEmbedding(nn.Module): self.params = nn.ParameterList() self.mapping = torch.zeros(0, device=device, dtype=torch.long) - def forward(self, input_ids: Optional[torch.LongTensor] = None): - if input_ids is None: - input_ids = torch.arange(self.mapping.shape[0]) - + def forward(self, input_ids: torch.LongTensor): ids = self.mapping[input_ids.to(self.mapping.device)] mask = ~(ids == -1) @@ -43,6 +40,12 @@ class PseudoSparseEmbedding(nn.Module): else: return [self.set(id) for id in input_ids] + if tensor is None: + tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype) + + if tensor.shape[-1] != self.embedding_dim: + raise ValueError(f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]") + id = self.mapping[input_ids] if id == -1: @@ -50,8 +53,7 @@ class PseudoSparseEmbedding(nn.Module): self.mapping[input_ids] = id self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)) - self.params[id] = tensor if tensor is not None else torch.zeros( - self.embedding_dim, device=self.mapping.device, dtype=self.dtype) + self.params[id] = tensor def unset(self, input_ids: torch.LongTensor): self.mapping[input_ids] = -1 -- cgit v1.2.3-54-g00ecf