 models/clip/embeddings.py |  2 +-
 models/lora.py            | 42 ++++++++++++++++++++----------------------
 2 files changed, 21 insertions(+), 23 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 4444cf9..d02ccc3 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -64,7 +64,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.token_embedding.mark_trainable(token_ids)
-        self.token_embedding.weight.data[token_ids] = initializer
+        self.token_embedding.weight[token_ids].data = initializer
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
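A note on the one-line change above: in PyTorch, indexing a parameter with a tensor of ids (weight[token_ids]) is advanced indexing and produces a new tensor, whereas indexing weight.data and assigning writes into the parameter's storage in place. A minimal standalone sketch of the difference; the tensor names and shapes below are placeholders, not code from this repository:

    import torch

    weight = torch.nn.Parameter(torch.zeros(10, 4))     # stand-in for token_embedding.weight
    token_ids = torch.tensor([2, 5], dtype=torch.long)
    initializer = torch.ones(2, 4)

    # In-place write: index the .data tensor, then assign into the selected rows.
    weight.data[token_ids] = initializer

    # Advanced indexing first: weight[token_ids] is a new tensor (a copy),
    # so rebinding that copy's .data leaves weight itself unchanged.
    weight[token_ids].data = initializer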
diff --git a/models/lora.py b/models/lora.py
index b7fa58f..a8197a5 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -42,7 +42,7 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
             self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
         )
 
-        self.register_buffer('trainable_ids', torch.zeros(num_embeddings, device=self.weight.device, dtype=torch.long))
+        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long))
         self.trainable_ids -= 1
 
         if r > 0:
@@ -76,7 +76,7 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
 
         return new_emb
 
-    def mark_trainable(self, input_ids):
+    def mark_trainable(self, input_ids: torch.LongTensor):
        trainable_ids = self.trainable_ids[input_ids]
        new_ids = input_ids[trainable_ids == -1]
 
@@ -87,15 +87,7 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         n2 = n1 + new_ids.shape[0]
         self.trainable_ids[new_ids] = torch.arange(n1, n2)
         for _ in new_ids:
-            self.lora_A.append(self.weight.new_zeros(self.r))
-
-    def persist(self):
-        if self.r > 0:
-            weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-            if weights is not None:
-                self.weight[mask].data += weights
-            self.trainable_ids[:] = -1
-            self.lora_A = nn.ParameterList()
+            self.lora_A.append(self.weight.new_zeros(self.r, requires_grad=True))
 
     def get_weights(self, input_ids: torch.Tensor):
         trainable_ids = self.trainable_ids[input_ids]
@@ -104,16 +96,25 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
 
         elems = [self.lora_A[id] for id in trainable_ids]
 
-        if len(elems) == 0:
-            return None, mask
-
-        weights = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
+        if len(elems) != 0:
+            weights = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
+        else:
+            weights = self.weight.new_zeros(self.embedding_dim)
 
         return weights, mask
 
+    def persist(self):
+        if self.r > 0:
+            weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+            if weights is not None:
+                self.weight[mask].data += weights
+            self.trainable_ids[:] = -1
+            self.lora_A = nn.ParameterList()
+
     def reset_parameters(self):
         nn.Embedding.reset_parameters(self)
         if hasattr(self, 'lora_A'):
+            self.trainable_ids[:] = -1
             self.lora_A = nn.ParameterList()
             nn.init.zeros_(self.lora_B.weight)
 
@@ -122,8 +123,7 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         if self.merge_weights and self.merged:
             if self.r > 0:
                 weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-                if weights is not None:
-                    self.weight[mask].data -= weights
+                self.weight[mask].data -= weights
             self.merged = False
 
     def eval(self):
@@ -131,16 +131,14 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         if self.merge_weights and not self.merged:
             if self.r > 0:
                 weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-                if weights is not None:
-                    self.weight[mask].data += weights
+                self.weight[mask].data += weights
             self.merged = True
 
-    def forward(self, input_ids: torch.Tensor):
+    def forward(self, input_ids: torch.LongTensor):
         result = nn.Embedding.forward(self, input_ids)
 
         if self.r > 0 and not self.merged:
             weights, mask = self.get_weights(input_ids)
-            if weights is not None:
-                result[mask] += weights
+            result[mask] += weights
 
         return result
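Taken together, the models/lora.py hunks make get_weights() always return a tensor (zeros when none of the selected ids are marked trainable), move persist() below get_weights(), and have reset_parameters() also clear the trainable_ids buffer. A minimal usage sketch of the mark_trainable → forward → persist flow; the constructor arguments and token ids below are assumed for illustration and are not taken from this diff:

    import torch
    from models.lora import LoraEmbedding  # module path as in this diff

    # Hypothetical instantiation; the exact __init__ signature is not shown here.
    emb = LoraEmbedding(49408, 768, r=8, lora_alpha=8)

    # Mark two token ids as trainable; each one gets its own lora_A vector
    # and a slot in the trainable_ids buffer.
    token_ids = torch.tensor([49406, 49407], dtype=torch.long)
    emb.mark_trainable(token_ids)

    # With r > 0 and weights not merged, forward() adds
    # lora_B(lora_dropout(lora_A)) * scaling onto the embedded rows via the mask.
    out = emb(token_ids)

    # persist() runs get_weights() over all ids, adds the result onto the
    # masked weight rows, and resets trainable_ids and lora_A, as in the
    # hunk that relocates the method.
    emb.persist()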
