diff options
Diffstat (limited to 'models/sparse.py')
| -rw-r--r-- | models/sparse.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/models/sparse.py b/models/sparse.py index bd45696..e5897c9 100644 --- a/models/sparse.py +++ b/models/sparse.py | |||
| @@ -15,21 +15,25 @@ class SparseEmbedding(nn.Embedding): | |||
| 15 | ): | 15 | ): |
| 16 | nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) | 16 | nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) |
| 17 | 17 | ||
| 18 | self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1) | 18 | self.register_buffer( |
| 19 | "trainable_ids", self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1 | ||
| 20 | ) | ||
| 19 | 21 | ||
| 20 | self.trainable = nn.ParameterList() | 22 | self.trainable = nn.ParameterList() |
| 21 | self.scaling = alpha | 23 | self.scaling = alpha |
| 22 | self.dropout_p = dropout | 24 | self.dropout_p = dropout |
| 23 | self.weight.requires_grad = False | 25 | self.weight.requires_grad = False |
| 24 | 26 | ||
| 25 | if dropout > 0.: | 27 | if dropout > 0.0: |
| 26 | self.dropout = nn.Dropout(p=dropout) | 28 | self.dropout = nn.Dropout(p=dropout) |
| 27 | else: | 29 | else: |
| 28 | self.dropout = nn.Identity() | 30 | self.dropout = nn.Identity() |
| 29 | 31 | ||
| 30 | self.reset_parameters() | 32 | self.reset_parameters() |
| 31 | 33 | ||
| 32 | def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None): | 34 | def new_resized( |
| 35 | self, new_num_embeddings: int, initializer_factor: Optional[float] = None | ||
| 36 | ): | ||
| 33 | n = min(self.num_embeddings, new_num_embeddings) | 37 | n = min(self.num_embeddings, new_num_embeddings) |
| 34 | 38 | ||
| 35 | new_emb = SparseEmbedding( | 39 | new_emb = SparseEmbedding( |
| @@ -38,7 +42,7 @@ class SparseEmbedding(nn.Embedding): | |||
| 38 | self.scaling, | 42 | self.scaling, |
| 39 | self.dropout_p, | 43 | self.dropout_p, |
| 40 | device=self.weight.device, | 44 | device=self.weight.device, |
| 41 | dtype=self.weight.dtype | 45 | dtype=self.weight.dtype, |
| 42 | ) | 46 | ) |
| 43 | if initializer_factor is not None: | 47 | if initializer_factor is not None: |
| 44 | new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) | 48 | new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) |
