"""A frozen ``nn.Embedding`` whose rows can be selectively unfrozen by
attaching sparse, trainable per-token delta vectors."""

from typing import Optional

import torch
import torch.nn as nn


class SparseEmbedding(nn.Embedding):
    """Embedding with a frozen weight table plus sparse trainable deltas.

    ``trainable_ids`` maps each vocabulary id to an index into the
    ``trainable`` parameter list (-1 means the row is frozen). The forward
    pass returns the frozen lookup plus ``scaling`` times the (optionally
    dropped-out) delta for each id.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        alpha: float = 1.0,
        dropout: float = 0.0,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        # Maps vocabulary id -> index into self.trainable; -1 means frozen.
        self.register_buffer(
            "trainable_ids",
            torch.full(
                (num_embeddings,), -1, dtype=torch.long, device=self.weight.device
            ),
        )
        self.trainable = nn.ParameterList()
        self.scaling = alpha
        self.dropout_p = dropout
        # Only the delta vectors receive gradients; the base table is frozen.
        self.weight.requires_grad = False
        self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
        self.reset_parameters()

    def new_resized(
        self, new_num_embeddings: int, initializer_factor: Optional[float] = None
    ):
        """Return a resized copy that shares the existing trainable deltas.

        Rows beyond the copied range are zero-initialized unless an
        ``initializer_factor`` is given, in which case they are drawn from a
        normal distribution with std ``initializer_factor * 0.02``.
        """
        n = min(self.num_embeddings, new_num_embeddings)
        new_emb = SparseEmbedding(
            new_num_embeddings,
            self.embedding_dim,
            self.scaling,
            self.dropout_p,
            device=self.weight.device,
            dtype=self.weight.dtype,
        )
        if initializer_factor is not None:
            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
        else:
            nn.init.zeros_(new_emb.weight.data)
        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
        for param in self.trainable:
            new_emb.trainable.append(param)
        new_emb.trainable_ids[:n] = self.trainable_ids[:n]
        return new_emb

    def mark_trainable(self, input_ids: torch.LongTensor):
        """Allocate a zero delta for every id that does not have one yet."""
        trainable_ids = self.trainable_ids[input_ids]
        # Deduplicate so repeated ids do not allocate orphaned parameters.
        new_ids = torch.unique(input_ids[trainable_ids == -1])
        if new_ids.shape[0] == 0:
            return
        n1 = len(self.trainable)
        n2 = n1 + new_ids.shape[0]
        self.trainable_ids[new_ids] = torch.arange(
            n1, n2, device=self.trainable_ids.device
        )
        for _ in range(new_ids.shape[0]):
            # Deltas start at zero, so marking a row does not change outputs.
            self.trainable.append(
                nn.Parameter(self.weight.new_zeros(self.embedding_dim))
            )

    def get_weights(self, input_ids: torch.Tensor):
        """Gather scaled delta vectors for ``input_ids``; zeros for frozen rows."""
        original_shape = input_ids.shape
        input_ids = input_ids.reshape(-1)
        weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim))
        trainable_ids = self.trainable_ids[input_ids]
        mask = trainable_ids != -1
        elems = [self.trainable[idx] for idx in trainable_ids[mask].tolist()]
        if len(elems) != 0:
            w = self.dropout(torch.stack(elems)) * self.scaling
            weights[mask] = w.to(dtype=weights.dtype)
        return weights.view(*original_shape, -1)

    @torch.no_grad()
    def persist(self, clear: bool = False):
        """Fold the current deltas into the frozen weight table.

        With ``clear=True`` the deltas and their id map are discarded;
        otherwise the deltas are reset to zero. Call in ``eval()`` mode so
        dropout cannot perturb the persisted weights.
        """
        all_ids = torch.arange(
            self.trainable_ids.shape[0], device=self.trainable_ids.device
        )
        self.weight.data += self.get_weights(all_ids)
        if clear:
            self.trainable_ids[:] = -1
            self.trainable = nn.ParameterList()
        else:
            for param in self.trainable:
                param.zero_()

    def reset_parameters(self):
        nn.Embedding.reset_parameters(self)
        # Guard: nn.Embedding.__init__ calls reset_parameters before the
        # sparse state exists.
        if hasattr(self, "trainable"):
            self.trainable_ids[:] = -1
            self.trainable = nn.ParameterList()

    def train(self, mode: bool = True):
        # Module.train already recurses into self.trainable; the explicit
        # call is kept for clarity. Returns self, matching nn.Module.train.
        nn.Embedding.train(self, mode)
        self.trainable.train(mode)
        return self

    def eval(self):
        return self.train(False)

    def forward(self, input_ids: torch.LongTensor):
        # Frozen lookup plus the sparse trainable deltas.
        result = nn.Embedding.forward(self, input_ids)
        return result + self.get_weights(input_ids)
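

# ---------------------------------------------------------------------------
# Usage sketch. This block is illustrative, not part of the original module:
# the vocabulary size, token ids, and SGD settings below are assumptions
# chosen only to exercise the API.
if __name__ == "__main__":
    emb = SparseEmbedding(num_embeddings=100, embedding_dim=16, alpha=2.0)

    # Unfreeze deltas for a few (hypothetical) token ids; duplicates are
    # removed inside mark_trainable, so only three parameters are allocated.
    emb.mark_trainable(torch.tensor([3, 7, 7, 42]))
    assert len(emb.trainable) == 3

    # Only the delta vectors are optimized; the base table stays frozen.
    opt = torch.optim.SGD(emb.trainable.parameters(), lr=0.1)
    out = emb(torch.tensor([[3, 7], [42, 0]]))  # shape (2, 2, 16)
    out.sum().backward()
    opt.step()

    # Fold the learned deltas back into the frozen table (eval mode so
    # dropout cannot perturb the persisted weights), then discard them.
    emb.eval()
    emb.persist(clear=True)
    assert len(emb.trainable) == 0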