From 12b9aca96a36dd77a6b2b99bbc1743d87a7ce733 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Jun 2023 21:00:29 +0200 Subject: Update --- models/sparse.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'models/sparse.py') diff --git a/models/sparse.py b/models/sparse.py index e5897c9..55c9837 100644 --- a/models/sparse.py +++ b/models/sparse.py @@ -89,10 +89,15 @@ class SparseEmbedding(nn.Embedding): return weights - def persist(self): + def persist(self, clear=False): self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0])) - self.trainable_ids[:] = -1 - self.trainable = nn.ParameterList() + + if clear: + self.trainable_ids[:] = -1 + self.trainable = nn.ParameterList() + else: + for param in self.trainable: + param.zero_() def reset_parameters(self): nn.Embedding.reset_parameters(self) -- cgit v1.2.3-54-g00ecf