diff options
Diffstat (limited to 'models/sparse.py')
-rw-r--r-- | models/sparse.py | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/models/sparse.py b/models/sparse.py index e5897c9..55c9837 100644 --- a/models/sparse.py +++ b/models/sparse.py | |||
@@ -89,10 +89,15 @@ class SparseEmbedding(nn.Embedding): | |||
89 | 89 | ||
90 | return weights | 90 | return weights |
91 | 91 | ||
92 | def persist(self): | 92 | def persist(self, clear=False): |
93 | self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0])) | 93 | self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0])) |
94 | self.trainable_ids[:] = -1 | 94 | |
95 | self.trainable = nn.ParameterList() | 95 | if clear: |
96 | self.trainable_ids[:] = -1 | ||
97 | self.trainable = nn.ParameterList() | ||
98 | else: | ||
99 | for param in self.trainable: | ||
100 | param.zero_() | ||
96 | 101 | ||
97 | def reset_parameters(self): | 102 | def reset_parameters(self): |
98 | nn.Embedding.reset_parameters(self) | 103 | nn.Embedding.reset_parameters(self) |