diff options
Diffstat (limited to 'models/sparse.py')
| -rw-r--r-- | models/sparse.py | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/models/sparse.py b/models/sparse.py index e5897c9..55c9837 100644 --- a/models/sparse.py +++ b/models/sparse.py | |||
| @@ -89,10 +89,15 @@ class SparseEmbedding(nn.Embedding): | |||
| 89 | 89 | ||
| 90 | return weights | 90 | return weights |
| 91 | 91 | ||
| 92 | def persist(self): | 92 | def persist(self, clear=False): |
| 93 | self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0])) | 93 | self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0])) |
| 94 | self.trainable_ids[:] = -1 | 94 | |
| 95 | self.trainable = nn.ParameterList() | 95 | if clear: |
| 96 | self.trainable_ids[:] = -1 | ||
| 97 | self.trainable = nn.ParameterList() | ||
| 98 | else: | ||
| 99 | for param in self.trainable: | ||
| 100 | param.zero_() | ||
| 96 | 101 | ||
| 97 | def reset_parameters(self): | 102 | def reset_parameters(self): |
| 98 | nn.Embedding.reset_parameters(self) | 103 | nn.Embedding.reset_parameters(self) |
