from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoraLayer:
    """Mixin holding the hyperparameters and merge state shared by LoRA layers."""

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout_p = lora_dropout
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = nn.Identity()
        # Whether the LoRA delta is currently folded into the base weight.
        self.merged = False
        self.merge_weights = merge_weights


class LoraEmbedding(nn.Embedding, LoraLayer):
    """Embedding layer with per-token LoRA updates.

    Only token ids registered via ``mark_trainable`` get a trainable rank-``r``
    vector in ``lora_A``; the shared projection ``lora_B`` maps it back to
    ``embedding_dim``. The base embedding weight stays frozen.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        merge_weights: bool = True,
        **kwargs,
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoraLayer.__init__(
            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
            merge_weights=merge_weights,
        )
        # Maps a token id to its index in lora_A; -1 means "no LoRA update".
        self.register_buffer(
            'trainable_ids',
            torch.full((num_embeddings,), -1, device=self.weight.device, dtype=torch.long),
        )
        if r > 0:
            self.lora_A = nn.ParameterList()
            self.lora_B = nn.Linear(
                r, embedding_dim, bias=False,
                device=self.weight.device, dtype=self.weight.dtype,
            )
            self.scaling = self.lora_alpha / self.r
            # Freeze the base embedding; only the LoRA parameters are trained.
            self.weight.requires_grad = False
        self.reset_parameters()

    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
        """Return a copy with a resized vocabulary that shares the LoRA state."""
        n = min(self.num_embeddings, new_num_embeddings)
        new_emb = LoraEmbedding(
            new_num_embeddings, self.embedding_dim, self.r, self.lora_alpha,
            self.lora_dropout_p, self.merge_weights,
            device=self.weight.device, dtype=self.weight.dtype,
        )
        if initializer_factor is not None:
            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
        else:
            nn.init.zeros_(new_emb.weight.data)
        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
        new_emb.lora_A = self.lora_A
        new_emb.lora_B = self.lora_B
        new_emb.trainable_ids[:n] = self.trainable_ids[:n]
        return new_emb

    def mark_trainable(self, input_ids):
        """Allocate a LoRA row for every id in ``input_ids`` that has none yet."""
        trainable_ids = self.trainable_ids[input_ids]
        # Deduplicate so a token repeated in the batch gets exactly one lora_A entry.
        new_ids = torch.unique(input_ids[trainable_ids == -1])
        if new_ids.shape[0] == 0:
            return
        n1 = len(self.lora_A)
        n2 = n1 + new_ids.shape[0]
        self.trainable_ids[new_ids] = torch.arange(n1, n2, device=self.trainable_ids.device)
        for _ in new_ids:
            # New rows start at zero so the layer's output is unchanged until trained.
            self.lora_A.append(nn.Parameter(self.weight.new_zeros(self.r)))

    def persist(self):
        """Fold the current LoRA deltas into the base weight and reset the LoRA state.

        Call this in eval mode so ``lora_dropout`` is inactive and the merged
        delta is deterministic.
        """
        if self.r > 0:
            weights, mask = self.get_weights(
                torch.arange(self.trainable_ids.shape[0], device=self.trainable_ids.device)
            )
            if weights is not None:
                self.weight.data[mask] += weights
            self.trainable_ids[:] = -1
            self.lora_A = nn.ParameterList()

    def get_weights(self, input_ids: torch.Tensor):
        """Return (delta, mask) for the positions in ``input_ids`` that have a LoRA row."""
        trainable_ids = self.trainable_ids[input_ids]
        mask = trainable_ids != -1
        trainable_ids = trainable_ids[mask]
        elems = [self.lora_A[int(idx)] for idx in trainable_ids]
        if len(elems) == 0:
            return None, mask
        weights = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
        return weights, mask

    def reset_parameters(self):
        nn.Embedding.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            self.lora_A = nn.ParameterList()
            # Per-token lora_A rows start at zero, so the initial delta is zero;
            # lora_B is initialized normally so gradients can still reach lora_A.
            nn.init.normal_(self.lora_B.weight)

    def train(self, mode: bool = True):
        # Un-merge before switching submodules (including lora_dropout) back to
        # training mode, so the subtracted delta matches what was added on merge.
        if self.merge_weights and self.merged:
            if self.r > 0:
                weights, mask = self.get_weights(
                    torch.arange(self.trainable_ids.shape[0], device=self.trainable_ids.device)
                )
                if weights is not None:
                    self.weight.data[mask] -= weights
            self.merged = False
        nn.Embedding.train(self, mode)
        return self

    def eval(self):
        nn.Embedding.eval(self)
        if self.merge_weights and not self.merged:
            if self.r > 0:
                weights, mask = self.get_weights(
                    torch.arange(self.trainable_ids.shape[0], device=self.trainable_ids.device)
                )
                if weights is not None:
                    self.weight.data[mask] += weights
            self.merged = True
        return self

    def forward(self, input_ids: torch.Tensor):
        result = nn.Embedding.forward(self, input_ids)
        if self.r > 0 and not self.merged:
            weights, mask = self.get_weights(input_ids)
            if weights is not None:
                # Add LoRA deltas only at positions whose ids are marked trainable.
                result[mask] += weights
        return result
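

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module above): it
# shows one way the class could be exercised — build a small LoraEmbedding,
# mark a few token ids as trainable, take a training step that touches only
# the LoRA parameters, then fold the deltas into the frozen base weight.
# All sizes, ids, and hyperparameters here are made-up example values;
# merge_weights=False is chosen so eval() does not auto-merge and persist()
# performs a single, permanent fold.
if __name__ == "__main__":
    emb = LoraEmbedding(
        num_embeddings=32, embedding_dim=8,
        r=4, lora_alpha=8, lora_dropout=0.1, merge_weights=False,
    )

    input_ids = torch.tensor([[1, 5, 7], [5, 2, 7]])
    # Only ids registered here get their own rank-r vector in lora_A.
    emb.mark_trainable(input_ids)

    # The base embedding weight is frozen; only lora_A / lora_B are optimized.
    opt = torch.optim.SGD([p for p in emb.parameters() if p.requires_grad], lr=0.1)

    emb.train()
    out = emb(input_ids)          # shape (2, 3, 8); LoRA deltas added for marked ids
    out.sum().backward()
    opt.step()

    # Fold the learned deltas into `weight` for inference and clear LoRA state.
    emb.eval()                    # dropout off; no auto-merge since merge_weights=False
    emb.persist()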