diff options
Diffstat (limited to 'models/sparse.py')
-rw-r--r-- | models/sparse.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/models/sparse.py b/models/sparse.py index bd45696..e5897c9 100644 --- a/models/sparse.py +++ b/models/sparse.py | |||
@@ -15,21 +15,25 @@ class SparseEmbedding(nn.Embedding): | |||
15 | ): | 15 | ): |
16 | nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) | 16 | nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) |
17 | 17 | ||
18 | self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1) | 18 | self.register_buffer( |
19 | "trainable_ids", self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1 | ||
20 | ) | ||
19 | 21 | ||
20 | self.trainable = nn.ParameterList() | 22 | self.trainable = nn.ParameterList() |
21 | self.scaling = alpha | 23 | self.scaling = alpha |
22 | self.dropout_p = dropout | 24 | self.dropout_p = dropout |
23 | self.weight.requires_grad = False | 25 | self.weight.requires_grad = False |
24 | 26 | ||
25 | if dropout > 0.: | 27 | if dropout > 0.0: |
26 | self.dropout = nn.Dropout(p=dropout) | 28 | self.dropout = nn.Dropout(p=dropout) |
27 | else: | 29 | else: |
28 | self.dropout = nn.Identity() | 30 | self.dropout = nn.Identity() |
29 | 31 | ||
30 | self.reset_parameters() | 32 | self.reset_parameters() |
31 | 33 | ||
32 | def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None): | 34 | def new_resized( |
35 | self, new_num_embeddings: int, initializer_factor: Optional[float] = None | ||
36 | ): | ||
33 | n = min(self.num_embeddings, new_num_embeddings) | 37 | n = min(self.num_embeddings, new_num_embeddings) |
34 | 38 | ||
35 | new_emb = SparseEmbedding( | 39 | new_emb = SparseEmbedding( |
@@ -38,7 +42,7 @@ class SparseEmbedding(nn.Embedding): | |||
38 | self.scaling, | 42 | self.scaling, |
39 | self.dropout_p, | 43 | self.dropout_p, |
40 | device=self.weight.device, | 44 | device=self.weight.device, |
41 | dtype=self.weight.dtype | 45 | dtype=self.weight.dtype, |
42 | ) | 46 | ) |
43 | if initializer_factor is not None: | 47 | if initializer_factor is not None: |
44 | new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) | 48 | new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) |