From 6aa4b5e199f0028db74b21646432b2203e0888d8 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 16 Apr 2023 11:21:23 +0200 Subject: Fix --- models/lora.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) (limited to 'models/lora.py') diff --git a/models/lora.py b/models/lora.py index a8197a5..01a540b 100644 --- a/models/lora.py +++ b/models/lora.py @@ -87,29 +87,31 @@ class LoraEmbedding(nn.Embedding, LoraLayer): n2 = n1 + new_ids.shape[0] self.trainable_ids[new_ids] = torch.arange(n1, n2) for _ in new_ids: - self.lora_A.append(self.weight.new_zeros(self.r, requires_grad=True)) + self.lora_A.append(nn.Parameter(self.weight.new_zeros(self.r))) def get_weights(self, input_ids: torch.Tensor): + if len(input_ids.shape) != 1: + return torch.stack([self.get_weights(batch) for batch in input_ids]) + trainable_ids = self.trainable_ids[input_ids] mask = ~(trainable_ids == -1) trainable_ids = trainable_ids[mask] + weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim)) elems = [self.lora_A[id] for id in trainable_ids] if len(elems) != 0: - weights = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling - else: - weights = self.weight.new_zeros(self.embedding_dim) + w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling + weights[mask] = w.to(dtype=weights.dtype) - return weights, mask + return weights def persist(self): if self.r > 0: - weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0])) - if weights is not None: - self.weight[mask].data += weights - self.trainable_ids[:] = -1 - self.lora_A = nn.ParameterList() + weights = self.get_weights(torch.arange(self.trainable_ids.shape[0])) + self.weight.data += weights + self.trainable_ids[:] = -1 + self.lora_A = nn.ParameterList() def reset_parameters(self): nn.Embedding.reset_parameters(self) @@ -122,23 +124,23 @@ class LoraEmbedding(nn.Embedding, LoraLayer): nn.Embedding.train(self, mode) if self.merge_weights and self.merged: if self.r > 0: - weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0])) - self.weight[mask].data -= weights + weights = self.get_weights(torch.arange(self.trainable_ids.shape[0])) + self.weight.data -= weights self.merged = False def eval(self): nn.Embedding.eval(self) if self.merge_weights and not self.merged: if self.r > 0: - weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0])) - self.weight[mask].data += weights + weights = self.get_weights(torch.arange(self.trainable_ids.shape[0])) + self.weight.data += weights self.merged = True def forward(self, input_ids: torch.LongTensor): result = nn.Embedding.forward(self, input_ids) if self.r > 0 and not self.merged: - weights, mask = self.get_weights(input_ids) - result[mask] += weights + weights = self.get_weights(input_ids) + result += weights return result -- cgit v1.2.3-54-g00ecf