Diffstat (limited to 'models')
-rw-r--r--  models/lora.py  34
1 file changed, 18 insertions(+), 16 deletions(-)
diff --git a/models/lora.py b/models/lora.py
index a8197a5..01a540b 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -87,29 +87,31 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         n2 = n1 + new_ids.shape[0]
         self.trainable_ids[new_ids] = torch.arange(n1, n2)
         for _ in new_ids:
-            self.lora_A.append(self.weight.new_zeros(self.r, requires_grad=True))
+            self.lora_A.append(nn.Parameter(self.weight.new_zeros(self.r)))

     def get_weights(self, input_ids: torch.Tensor):
+        if len(input_ids.shape) != 1:
+            return torch.stack([self.get_weights(batch) for batch in input_ids])
+
         trainable_ids = self.trainable_ids[input_ids]
         mask = ~(trainable_ids == -1)
         trainable_ids = trainable_ids[mask]

+        weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim))
         elems = [self.lora_A[id] for id in trainable_ids]

         if len(elems) != 0:
-            weights = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
-        else:
-            weights = self.weight.new_zeros(self.embedding_dim)
+            w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
+            weights[mask] = w.to(dtype=weights.dtype)

-        return weights, mask
+        return weights

     def persist(self):
         if self.r > 0:
-            weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-            if weights is not None:
-                self.weight[mask].data += weights
-                self.trainable_ids[:] = -1
-                self.lora_A = nn.ParameterList()
+            weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+            self.weight.data += weights
+            self.trainable_ids[:] = -1
+            self.lora_A = nn.ParameterList()

     def reset_parameters(self):
         nn.Embedding.reset_parameters(self)
@@ -122,23 +124,23 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         nn.Embedding.train(self, mode)
         if self.merge_weights and self.merged:
             if self.r > 0:
-                weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-                self.weight[mask].data -= weights
+                weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+                self.weight.data -= weights
             self.merged = False

     def eval(self):
         nn.Embedding.eval(self)
         if self.merge_weights and not self.merged:
             if self.r > 0:
-                weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-                self.weight[mask].data += weights
+                weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+                self.weight.data += weights
             self.merged = True

     def forward(self, input_ids: torch.LongTensor):
         result = nn.Embedding.forward(self, input_ids)

         if self.r > 0 and not self.merged:
-            weights, mask = self.get_weights(input_ids)
-            result[mask] += weights
+            weights = self.get_weights(input_ids)
+            result += weights

         return result