diff options
Diffstat (limited to 'models')
| -rw-r--r-- | models/lora.py | 34 |
1 files changed, 18 insertions, 16 deletions
diff --git a/models/lora.py b/models/lora.py index a8197a5..01a540b 100644 --- a/models/lora.py +++ b/models/lora.py | |||
| @@ -87,29 +87,31 @@ class LoraEmbedding(nn.Embedding, LoraLayer): | |||
| 87 | n2 = n1 + new_ids.shape[0] | 87 | n2 = n1 + new_ids.shape[0] |
| 88 | self.trainable_ids[new_ids] = torch.arange(n1, n2) | 88 | self.trainable_ids[new_ids] = torch.arange(n1, n2) |
| 89 | for _ in new_ids: | 89 | for _ in new_ids: |
| 90 | self.lora_A.append(self.weight.new_zeros(self.r, requires_grad=True)) | 90 | self.lora_A.append(nn.Parameter(self.weight.new_zeros(self.r))) |
| 91 | 91 | ||
| 92 | def get_weights(self, input_ids: torch.Tensor): | 92 | def get_weights(self, input_ids: torch.Tensor): |
| 93 | if len(input_ids.shape) != 1: | ||
| 94 | return torch.stack([self.get_weights(batch) for batch in input_ids]) | ||
| 95 | |||
| 93 | trainable_ids = self.trainable_ids[input_ids] | 96 | trainable_ids = self.trainable_ids[input_ids] |
| 94 | mask = ~(trainable_ids == -1) | 97 | mask = ~(trainable_ids == -1) |
| 95 | trainable_ids = trainable_ids[mask] | 98 | trainable_ids = trainable_ids[mask] |
| 96 | 99 | ||
| 100 | weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim)) | ||
| 97 | elems = [self.lora_A[id] for id in trainable_ids] | 101 | elems = [self.lora_A[id] for id in trainable_ids] |
| 98 | 102 | ||
| 99 | if len(elems) != 0: | 103 | if len(elems) != 0: |
| 100 | weights = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling | 104 | w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling |
| 101 | else: | 105 | weights[mask] = w.to(dtype=weights.dtype) |
| 102 | weights = self.weight.new_zeros(self.embedding_dim) | ||
| 103 | 106 | ||
| 104 | return weights, mask | 107 | return weights |
| 105 | 108 | ||
| 106 | def persist(self): | 109 | def persist(self): |
| 107 | if self.r > 0: | 110 | if self.r > 0: |
| 108 | weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0])) | 111 | weights = self.get_weights(torch.arange(self.trainable_ids.shape[0])) |
| 109 | if weights is not None: | 112 | self.weight.data += weights |
| 110 | self.weight[mask].data += weights | 113 | self.trainable_ids[:] = -1 |
| 111 | self.trainable_ids[:] = -1 | 114 | self.lora_A = nn.ParameterList() |
| 112 | self.lora_A = nn.ParameterList() | ||
| 113 | 115 | ||
| 114 | def reset_parameters(self): | 116 | def reset_parameters(self): |
| 115 | nn.Embedding.reset_parameters(self) | 117 | nn.Embedding.reset_parameters(self) |
| @@ -122,23 +124,23 @@ class LoraEmbedding(nn.Embedding, LoraLayer): | |||
| 122 | nn.Embedding.train(self, mode) | 124 | nn.Embedding.train(self, mode) |
| 123 | if self.merge_weights and self.merged: | 125 | if self.merge_weights and self.merged: |
| 124 | if self.r > 0: | 126 | if self.r > 0: |
| 125 | weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0])) | 127 | weights = self.get_weights(torch.arange(self.trainable_ids.shape[0])) |
| 126 | self.weight[mask].data -= weights | 128 | self.weight.data -= weights |
| 127 | self.merged = False | 129 | self.merged = False |
| 128 | 130 | ||
| 129 | def eval(self): | 131 | def eval(self): |
| 130 | nn.Embedding.eval(self) | 132 | nn.Embedding.eval(self) |
| 131 | if self.merge_weights and not self.merged: | 133 | if self.merge_weights and not self.merged: |
| 132 | if self.r > 0: | 134 | if self.r > 0: |
| 133 | weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0])) | 135 | weights = self.get_weights(torch.arange(self.trainable_ids.shape[0])) |
| 134 | self.weight[mask].data += weights | 136 | self.weight.data += weights |
| 135 | self.merged = True | 137 | self.merged = True |
| 136 | 138 | ||
| 137 | def forward(self, input_ids: torch.LongTensor): | 139 | def forward(self, input_ids: torch.LongTensor): |
| 138 | result = nn.Embedding.forward(self, input_ids) | 140 | result = nn.Embedding.forward(self, input_ids) |
| 139 | 141 | ||
| 140 | if self.r > 0 and not self.merged: | 142 | if self.r > 0 and not self.merged: |
| 141 | weights, mask = self.get_weights(input_ids) | 143 | weights = self.get_weights(input_ids) |
| 142 | result[mask] += weights | 144 | result += weights |
| 143 | 145 | ||
| 144 | return result | 146 | return result |
