diff options
-rw-r--r-- | models/lora.py | 34 |
1 files changed, 18 insertions, 16 deletions
diff --git a/models/lora.py b/models/lora.py index a8197a5..01a540b 100644 --- a/models/lora.py +++ b/models/lora.py | |||
@@ -87,29 +87,31 @@ class LoraEmbedding(nn.Embedding, LoraLayer): | |||
87 | n2 = n1 + new_ids.shape[0] | 87 | n2 = n1 + new_ids.shape[0] |
88 | self.trainable_ids[new_ids] = torch.arange(n1, n2) | 88 | self.trainable_ids[new_ids] = torch.arange(n1, n2) |
89 | for _ in new_ids: | 89 | for _ in new_ids: |
90 | self.lora_A.append(self.weight.new_zeros(self.r, requires_grad=True)) | 90 | self.lora_A.append(nn.Parameter(self.weight.new_zeros(self.r))) |
91 | 91 | ||
92 | def get_weights(self, input_ids: torch.Tensor): | 92 | def get_weights(self, input_ids: torch.Tensor): |
93 | if len(input_ids.shape) != 1: | ||
94 | return torch.stack([self.get_weights(batch) for batch in input_ids]) | ||
95 | |||
93 | trainable_ids = self.trainable_ids[input_ids] | 96 | trainable_ids = self.trainable_ids[input_ids] |
94 | mask = ~(trainable_ids == -1) | 97 | mask = ~(trainable_ids == -1) |
95 | trainable_ids = trainable_ids[mask] | 98 | trainable_ids = trainable_ids[mask] |
96 | 99 | ||
100 | weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim)) | ||
97 | elems = [self.lora_A[id] for id in trainable_ids] | 101 | elems = [self.lora_A[id] for id in trainable_ids] |
98 | 102 | ||
99 | if len(elems) != 0: | 103 | if len(elems) != 0: |
100 | weights = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling | 104 | w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling |
101 | else: | 105 | weights[mask] = w.to(dtype=weights.dtype) |
102 | weights = self.weight.new_zeros(self.embedding_dim) | ||
103 | 106 | ||
104 | return weights, mask | 107 | return weights |
105 | 108 | ||
106 | def persist(self): | 109 | def persist(self): |
107 | if self.r > 0: | 110 | if self.r > 0: |
108 | weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0])) | 111 | weights = self.get_weights(torch.arange(self.trainable_ids.shape[0])) |
109 | if weights is not None: | 112 | self.weight.data += weights |
110 | self.weight[mask].data += weights | 113 | self.trainable_ids[:] = -1 |
111 | self.trainable_ids[:] = -1 | 114 | self.lora_A = nn.ParameterList() |
112 | self.lora_A = nn.ParameterList() | ||
113 | 115 | ||
114 | def reset_parameters(self): | 116 | def reset_parameters(self): |
115 | nn.Embedding.reset_parameters(self) | 117 | nn.Embedding.reset_parameters(self) |
@@ -122,23 +124,23 @@ class LoraEmbedding(nn.Embedding, LoraLayer): | |||
122 | nn.Embedding.train(self, mode) | 124 | nn.Embedding.train(self, mode) |
123 | if self.merge_weights and self.merged: | 125 | if self.merge_weights and self.merged: |
124 | if self.r > 0: | 126 | if self.r > 0: |
125 | weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0])) | 127 | weights = self.get_weights(torch.arange(self.trainable_ids.shape[0])) |
126 | self.weight[mask].data -= weights | 128 | self.weight.data -= weights |
127 | self.merged = False | 129 | self.merged = False |
128 | 130 | ||
129 | def eval(self): | 131 | def eval(self): |
130 | nn.Embedding.eval(self) | 132 | nn.Embedding.eval(self) |
131 | if self.merge_weights and not self.merged: | 133 | if self.merge_weights and not self.merged: |
132 | if self.r > 0: | 134 | if self.r > 0: |
133 | weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0])) | 135 | weights = self.get_weights(torch.arange(self.trainable_ids.shape[0])) |
134 | self.weight[mask].data += weights | 136 | self.weight.data += weights |
135 | self.merged = True | 137 | self.merged = True |
136 | 138 | ||
137 | def forward(self, input_ids: torch.LongTensor): | 139 | def forward(self, input_ids: torch.LongTensor): |
138 | result = nn.Embedding.forward(self, input_ids) | 140 | result = nn.Embedding.forward(self, input_ids) |
139 | 141 | ||
140 | if self.r > 0 and not self.merged: | 142 | if self.r > 0 and not self.merged: |
141 | weights, mask = self.get_weights(input_ids) | 143 | weights = self.get_weights(input_ids) |
142 | result[mask] += weights | 144 | result += weights |
143 | 145 | ||
144 | return result | 146 | return result |