From d488f66c78e444d03c4ef8a957b82f8b239379d0 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 15 Apr 2023 13:31:24 +0200
Subject: Fix

---
 models/lora.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

(limited to 'models/lora.py')

diff --git a/models/lora.py b/models/lora.py
index c0f74a6..98d4d2c 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -83,11 +83,11 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         if new_ids.shape[0] == 0:
             return
 
-        n = self.trainable_ids.shape[0]
-        self.trainable_ids[new_ids] = torch.arange(n, n + new_ids.shape[0])
+        n1 = self.lora_A.shape[1]
+        n2 = n1 + new_ids.shape[0]
+        self.trainable_ids[new_ids] = torch.arange(n1, n2)
 
-        lora_A = nn.Parameter(self.weight.new_zeros((self.trainable_ids.shape[0], 0)))
-        lora_A.data[:n] = self.lora_A.data
+        lora_A = nn.Parameter(self.weight.new_zeros((self.r, n2)))
         self.lora_A = lora_A
 
     def reset_parameters(self):
-- 
cgit v1.2.3-54-g00ecf
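
Note: below is a minimal, self-contained sketch of the fixed logic, for context only. The hunk shows just part of one method, so the method name mark_trainable, the computation of new_ids, and the constructor scaffolding are assumptions (input_ids is assumed 1-D); the real class also mixes in LoraLayer, which is omitted here to keep the sketch runnable. The patch replaces the wrong offset n = self.trainable_ids.shape[0] (the length of the full id map, i.e. the vocabulary size) with n1 = self.lora_A.shape[1] (the number of columns already allocated), and allocates the grown lora_A as an r x n2 zero tensor instead of a zero-width one.

    import torch
    import torch.nn as nn

    class LoraEmbedding(nn.Embedding):
        """Sketch only: the real class also inherits from LoraLayer."""

        def __init__(self, num_embeddings: int, embedding_dim: int, r: int = 8):
            super().__init__(num_embeddings, embedding_dim)
            self.r = r
            # Maps every embedding id to its column in lora_A, or -1 if frozen.
            self.register_buffer(
                "trainable_ids",
                torch.full((num_embeddings,), -1, dtype=torch.long),
            )
            # One column per trainable id; starts empty (r x 0).
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0)))

        def mark_trainable(self, input_ids: torch.LongTensor):
            # Assumed computation of new_ids: ids not yet marked trainable.
            new_ids = input_ids[self.trainable_ids[input_ids] == -1].unique()
            if new_ids.shape[0] == 0:
                return

            # Next free column is the current width of lora_A. The old code
            # used trainable_ids.shape[0] (the vocabulary size) here, which
            # handed out wrong, out-of-range column indices.
            n1 = self.lora_A.shape[1]
            n2 = n1 + new_ids.shape[0]
            self.trainable_ids[new_ids] = torch.arange(n1, n2)

            # Allocate the grown matrix as r x n2 zeros. The old code built a
            # zero-width tensor and copied stale data into it; the patch also
            # drops that copy, so earlier lora_A values are reset to zero.
            self.lora_A = nn.Parameter(self.weight.new_zeros((self.r, n2)))

With that reading, growing the trainable set behaves as follows (shapes follow directly from the patch's arithmetic):

    emb = LoraEmbedding(10, 4, r=2)
    emb.mark_trainable(torch.tensor([3, 5]))
    assert emb.lora_A.shape == (2, 2)
    emb.mark_trainable(torch.tensor([5, 7]))  # only 7 is new
    assert emb.lora_A.shape == (2, 3)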