summary refs log tree commit diff stats
path: root/models/sparse.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/sparse.py')
-rw-r--r--models/sparse.py12
1 file changed, 8 insertions(+), 4 deletions(-)
diff --git a/models/sparse.py b/models/sparse.py
index bd45696..e5897c9 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -15,21 +15,25 @@ class SparseEmbedding(nn.Embedding):
15 ): 15 ):
16 nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) 16 nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
17 17
18 self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1) 18 self.register_buffer(
19 "trainable_ids", self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1
20 )
19 21
20 self.trainable = nn.ParameterList() 22 self.trainable = nn.ParameterList()
21 self.scaling = alpha 23 self.scaling = alpha
22 self.dropout_p = dropout 24 self.dropout_p = dropout
23 self.weight.requires_grad = False 25 self.weight.requires_grad = False
24 26
25 if dropout > 0.: 27 if dropout > 0.0:
26 self.dropout = nn.Dropout(p=dropout) 28 self.dropout = nn.Dropout(p=dropout)
27 else: 29 else:
28 self.dropout = nn.Identity() 30 self.dropout = nn.Identity()
29 31
30 self.reset_parameters() 32 self.reset_parameters()
31 33
32 def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None): 34 def new_resized(
35 self, new_num_embeddings: int, initializer_factor: Optional[float] = None
36 ):
33 n = min(self.num_embeddings, new_num_embeddings) 37 n = min(self.num_embeddings, new_num_embeddings)
34 38
35 new_emb = SparseEmbedding( 39 new_emb = SparseEmbedding(
@@ -38,7 +42,7 @@ class SparseEmbedding(nn.Embedding):
38 self.scaling, 42 self.scaling,
39 self.dropout_p, 43 self.dropout_p,
40 device=self.weight.device, 44 device=self.weight.device,
41 dtype=self.weight.dtype 45 dtype=self.weight.dtype,
42 ) 46 )
43 if initializer_factor is not None: 47 if initializer_factor is not None:
44 new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) 48 new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)