author     Volpeon <git@volpeon.ink>  2023-04-15 13:31:24 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-15 13:31:24 +0200
commit     d488f66c78e444d03c4ef8a957b82f8b239379d0 (patch)
tree       864b2fe8d03b0cdfc3437622a0dcd5a1ede60e16 /models/lora.py
parent     TI via LoRA (diff)
Fix
Diffstat (limited to 'models/lora.py')
-rw-r--r--  models/lora.py | 8
1 file changed, 4 insertions, 4 deletions
diff --git a/models/lora.py b/models/lora.py
index c0f74a6..98d4d2c 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -83,11 +83,11 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         if new_ids.shape[0] == 0:
             return

-        n = self.trainable_ids.shape[0]
-        self.trainable_ids[new_ids] = torch.arange(n, n + new_ids.shape[0])
+        n1 = self.lora_A.shape[1]
+        n2 = n1 + new_ids.shape[0]
+        self.trainable_ids[new_ids] = torch.arange(n1, n2)

-        lora_A = nn.Parameter(self.weight.new_zeros((self.trainable_ids.shape[0], 0)))
-        lora_A.data[:n] = self.lora_A.data
+        lora_A = nn.Parameter(self.weight.new_zeros((self.r, n2)))
         self.lora_A = lora_A

     def reset_parameters(self):
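
For readers skimming the hunk, here is a minimal, self-contained sketch of the behaviour after the fix. It only assumes the attributes visible in the diff (r, weight, trainable_ids, lora_A); the class name TinyLoraEmbedding, its constructor arguments, the grow method name, and the filtering of already-mapped ids are hypothetical illustrations of the shape bookkeeping, not the repository's actual API.

    import torch
    import torch.nn as nn

    class TinyLoraEmbedding(nn.Embedding):
        # Hypothetical stand-in for LoraEmbedding: only the bookkeeping
        # touched by this commit is reproduced here.
        def __init__(self, num_embeddings, embedding_dim, r):
            super().__init__(num_embeddings, embedding_dim)
            self.r = r
            # -1 marks ids that have no LoRA column assigned yet.
            self.register_buffer(
                "trainable_ids",
                torch.full((num_embeddings,), -1, dtype=torch.long),
            )
            # lora_A starts with r rows and zero columns; columns are
            # appended as ids are marked trainable.
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0)))

        def grow(self, new_ids):
            # Mirrors the fixed hunk: one new lora_A column per new id.
            # Assumption: ids that already have a column are skipped.
            new_ids = new_ids[self.trainable_ids[new_ids] == -1]
            if new_ids.shape[0] == 0:
                return
            n1 = self.lora_A.shape[1]        # columns already allocated
            n2 = n1 + new_ids.shape[0]       # columns after growing
            self.trainable_ids[new_ids] = torch.arange(n1, n2)
            # Fresh zero-initialised (r, n2) matrix; within this hunk the
            # old values are no longer copied over, so earlier columns start
            # from zero again unless they are restored elsewhere.
            self.lora_A = nn.Parameter(self.weight.new_zeros((self.r, n2)))

    emb = TinyLoraEmbedding(10, 4, r=2)
    emb.grow(torch.tensor([3, 7]))
    print(emb.lora_A.shape)           # torch.Size([2, 2])
    print(emb.trainable_ids[[3, 7]])  # tensor([0, 1])
    emb.grow(torch.tensor([5]))
    print(emb.lora_A.shape)           # torch.Size([2, 3])

Before the fix, the replacement matrix was allocated as self.weight.new_zeros((self.trainable_ids.shape[0], 0)), i.e. with zero columns, and the old values were copied into that zero-width tensor; the commit sizes the allocation as (self.r, n2), one column per trainable id, and drops the explicit copy.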