From 8364ce697ddf6117fdd4f7222832d546d63880de Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 21 Jun 2023 13:28:49 +0200
Subject: Update

---
 models/sparse.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/models/sparse.py b/models/sparse.py
index bd45696..e5897c9 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -15,21 +15,25 @@ class SparseEmbedding(nn.Embedding):
     ):
         nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
 
-        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1)
+        self.register_buffer(
+            "trainable_ids", self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1
+        )
         self.trainable = nn.ParameterList()
         self.scaling = alpha
         self.dropout_p = dropout
         self.weight.requires_grad = False
 
-        if dropout > 0.:
+        if dropout > 0.0:
             self.dropout = nn.Dropout(p=dropout)
         else:
             self.dropout = nn.Identity()
 
         self.reset_parameters()
 
-    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
+    def new_resized(
+        self, new_num_embeddings: int, initializer_factor: Optional[float] = None
+    ):
         n = min(self.num_embeddings, new_num_embeddings)
 
         new_emb = SparseEmbedding(
@@ -38,7 +42,7 @@ class SparseEmbedding(nn.Embedding):
             self.scaling,
             self.dropout_p,
             device=self.weight.device,
-            dtype=self.weight.dtype
+            dtype=self.weight.dtype,
        )
         if initializer_factor is not None:
             new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
-- 
cgit v1.2.3-54-g00ecf
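
For context, a minimal usage sketch of the resizing path this patch touches. The import path, the vocabulary sizes, and the assumption that new_resized() returns the freshly built SparseEmbedding (only its construction is visible in the hunks) are illustrative, not taken from the patch; the (num_embeddings, embedding_dim, alpha, dropout) parameter order is inferred from the constructor call in the second hunk.

    import torch

    from models.sparse import SparseEmbedding  # assumed import path

    # Frozen base table (self.weight.requires_grad = False in __init__);
    # only ids registered in the trainable_ids buffer receive gradients.
    emb = SparseEmbedding(49408, 768, alpha=1.0, dropout=0.1)

    # Grow the table, e.g. after adding placeholder tokens to a tokenizer.
    # When initializer_factor is given, weights are redrawn with
    # std = initializer_factor * 0.02, matching the normal_() call in the diff;
    # copying the old min(n, new_n) rows over happens outside the visible hunk.
    emb = emb.new_resized(49408 + 4, initializer_factor=1.0)

    ids = torch.tensor([[0, 1, 49408]])
    out = emb(ids)        # lookup as with a plain nn.Embedding
    print(out.shape)      # torch.Size([1, 3, 768])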