From 1a0161f345191d78a19eec829f9d73b2c2c72f94 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 16 Apr 2023 09:44:12 +0200 Subject: Update --- models/lora.py | 59 ++++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 37 insertions(+), 22 deletions(-) (limited to 'models/lora.py') diff --git a/models/lora.py b/models/lora.py index 89c4b2e..b7fa58f 100644 --- a/models/lora.py +++ b/models/lora.py @@ -46,8 +46,8 @@ class LoraEmbedding(nn.Embedding, LoraLayer): self.trainable_ids -= 1 if r > 0: - self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0))) - self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r))) + self.lora_A = nn.ParameterList() + self.lora_B = nn.Linear(r, embedding_dim, bias=False) self.scaling = self.lora_alpha / self.r self.weight.requires_grad = False @@ -83,49 +83,64 @@ class LoraEmbedding(nn.Embedding, LoraLayer): if new_ids.shape[0] == 0: return - n1 = self.lora_A.shape[1] + n1 = len(self.lora_A) n2 = n1 + new_ids.shape[0] self.trainable_ids[new_ids] = torch.arange(n1, n2) + for _ in new_ids: + self.lora_A.append(self.weight.new_zeros(self.r)) + + def persist(self): + if self.r > 0: + weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0])) + if weights is not None: + self.weight[mask].data += weights + self.trainable_ids[:] = -1 + self.lora_A = nn.ParameterList() + + def get_weights(self, input_ids: torch.Tensor): + trainable_ids = self.trainable_ids[input_ids] + mask = ~(trainable_ids == -1) + trainable_ids = trainable_ids[mask] + + elems = [self.lora_A[id] for id in trainable_ids] + + if len(elems) == 0: + return None, mask + + weights = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling - lora_A = nn.Parameter(self.weight.new_zeros((self.r, n2))) - self.lora_A = lora_A + return weights, mask def reset_parameters(self): nn.Embedding.reset_parameters(self) if hasattr(self, 'lora_A'): - nn.init.zeros_(self.lora_A) - nn.init.normal_(self.lora_B) + self.lora_A = nn.ParameterList() + nn.init.zeros_(self.lora_B.weight) def train(self, mode: bool = True): nn.Embedding.train(self, mode) if self.merge_weights and self.merged: if self.r > 0: - mask = ~(self.trainable_ids == -1) - trainable_ids = self.trainable_ids[mask] - self.weight[trainable_ids].data -= (self.lora_B @ self.lora_A).T * self.scaling + weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0])) + if weights is not None: + self.weight[mask].data -= weights self.merged = False def eval(self): nn.Embedding.eval(self) if self.merge_weights and not self.merged: if self.r > 0: - mask = ~(self.trainable_ids == -1) - trainable_ids = self.trainable_ids[mask] - self.weight[trainable_ids].data += (self.lora_B @ self.lora_A) * self.scaling + weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0])) + if weights is not None: + self.weight[mask].data += weights self.merged = True def forward(self, input_ids: torch.Tensor): result = nn.Embedding.forward(self, input_ids) if self.r > 0 and not self.merged: - trainable_ids = self.trainable_ids[input_ids] - mask = ~(trainable_ids == -1) - trainable_ids = trainable_ids[mask] - - after_A = F.embedding( - trainable_ids, self.lora_A.T, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse - ) - result[mask] += (after_A @ self.lora_B.T) * self.scaling + weights, mask = self.get_weights(input_ids) + if weights is not None: + result[mask] += weights return result -- cgit v1.2.3-54-g00ecf