diff options
Diffstat (limited to 'models/sparse.py')
| -rw-r--r-- | models/sparse.py | 12 | 
1 files changed, 8 insertions, 4 deletions
diff --git a/models/sparse.py b/models/sparse.py index bd45696..e5897c9 100644 --- a/models/sparse.py +++ b/models/sparse.py  | |||
| @@ -15,21 +15,25 @@ class SparseEmbedding(nn.Embedding): | |||
| 15 | ): | 15 | ): | 
| 16 | nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) | 16 | nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) | 
| 17 | 17 | ||
| 18 | self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1) | 18 | self.register_buffer( | 
| 19 | "trainable_ids", self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1 | ||
| 20 | ) | ||
| 19 | 21 | ||
| 20 | self.trainable = nn.ParameterList() | 22 | self.trainable = nn.ParameterList() | 
| 21 | self.scaling = alpha | 23 | self.scaling = alpha | 
| 22 | self.dropout_p = dropout | 24 | self.dropout_p = dropout | 
| 23 | self.weight.requires_grad = False | 25 | self.weight.requires_grad = False | 
| 24 | 26 | ||
| 25 | if dropout > 0.: | 27 | if dropout > 0.0: | 
| 26 | self.dropout = nn.Dropout(p=dropout) | 28 | self.dropout = nn.Dropout(p=dropout) | 
| 27 | else: | 29 | else: | 
| 28 | self.dropout = nn.Identity() | 30 | self.dropout = nn.Identity() | 
| 29 | 31 | ||
| 30 | self.reset_parameters() | 32 | self.reset_parameters() | 
| 31 | 33 | ||
| 32 | def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None): | 34 | def new_resized( | 
| 35 | self, new_num_embeddings: int, initializer_factor: Optional[float] = None | ||
| 36 | ): | ||
| 33 | n = min(self.num_embeddings, new_num_embeddings) | 37 | n = min(self.num_embeddings, new_num_embeddings) | 
| 34 | 38 | ||
| 35 | new_emb = SparseEmbedding( | 39 | new_emb = SparseEmbedding( | 
| @@ -38,7 +42,7 @@ class SparseEmbedding(nn.Embedding): | |||
| 38 | self.scaling, | 42 | self.scaling, | 
| 39 | self.dropout_p, | 43 | self.dropout_p, | 
| 40 | device=self.weight.device, | 44 | device=self.weight.device, | 
| 41 | dtype=self.weight.dtype | 45 | dtype=self.weight.dtype, | 
| 42 | ) | 46 | ) | 
| 43 | if initializer_factor is not None: | 47 | if initializer_factor is not None: | 
| 44 | new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) | 48 | new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) | 
