summaryrefslogtreecommitdiffstats
path: root/models/lora.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/lora.py')
-rw-r--r--models/lora.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/models/lora.py b/models/lora.py
index c0f74a6..98d4d2c 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -83,11 +83,11 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
83 if new_ids.shape[0] == 0: 83 if new_ids.shape[0] == 0:
84 return 84 return
85 85
86 n = self.trainable_ids.shape[0] 86 n1 = self.lora_A.shape[1]
87 self.trainable_ids[new_ids] = torch.arange(n, n + new_ids.shape[0]) 87 n2 = n1 + new_ids.shape[0]
88 self.trainable_ids[new_ids] = torch.arange(n1, n2)
88 89
89 lora_A = nn.Parameter(self.weight.new_zeros((self.trainable_ids.shape[0], 0))) 90 lora_A = nn.Parameter(self.weight.new_zeros((self.r, n2)))
90 lora_A.data[:n] = self.lora_A.data
91 self.lora_A = lora_A 91 self.lora_A = lora_A
92 92
93 def reset_parameters(self): 93 def reset_parameters(self):