From 15d1a15d1010509c8a2a6dd1ffa47b81e7bc0b78 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 16 Apr 2023 10:37:04 +0200 Subject: Fix --- models/clip/embeddings.py | 2 +- models/lora.py | 42 ++++++++++++++++++++---------------------- 2 files changed, 21 insertions(+), 23 deletions(-) (limited to 'models') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 4444cf9..d02ccc3 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -64,7 +64,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): token_ids = torch.tensor(token_ids, dtype=torch.long) self.token_embedding.mark_trainable(token_ids) - self.token_embedding.weight.data[token_ids] = initializer + self.token_embedding.weight[token_ids].data = initializer def load_embed(self, input_ids: list[int], filename: Path): with safe_open(filename, framework="pt", device="cpu") as file: diff --git a/models/lora.py b/models/lora.py index b7fa58f..a8197a5 100644 --- a/models/lora.py +++ b/models/lora.py @@ -42,7 +42,7 @@ class LoraEmbedding(nn.Embedding, LoraLayer): self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights ) - self.register_buffer('trainable_ids', torch.zeros(num_embeddings, device=self.weight.device, dtype=torch.long)) + self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long)) self.trainable_ids -= 1 if r > 0: @@ -76,7 +76,7 @@ class LoraEmbedding(nn.Embedding, LoraLayer): return new_emb - def mark_trainable(self, input_ids): + def mark_trainable(self, input_ids: torch.LongTensor): trainable_ids = self.trainable_ids[input_ids] new_ids = input_ids[trainable_ids == -1] @@ -87,15 +87,7 @@ class LoraEmbedding(nn.Embedding, LoraLayer): n2 = n1 + new_ids.shape[0] self.trainable_ids[new_ids] = torch.arange(n1, n2) for _ in new_ids: - self.lora_A.append(self.weight.new_zeros(self.r)) - - def persist(self): - if self.r > 0: - weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0])) - if weights is not None: - self.weight[mask].data += weights - self.trainable_ids[:] = -1 - self.lora_A = nn.ParameterList() + self.lora_A.append(self.weight.new_zeros(self.r, requires_grad=True)) def get_weights(self, input_ids: torch.Tensor): trainable_ids = self.trainable_ids[input_ids] @@ -104,16 +96,25 @@ class LoraEmbedding(nn.Embedding, LoraLayer): elems = [self.lora_A[id] for id in trainable_ids] - if len(elems) == 0: - return None, mask - - weights = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling + if len(elems) != 0: + weights = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling + else: + weights = self.weight.new_zeros(self.embedding_dim) return weights, mask + def persist(self): + if self.r > 0: + weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0])) + if weights is not None: + self.weight[mask].data += weights + self.trainable_ids[:] = -1 + self.lora_A = nn.ParameterList() + def reset_parameters(self): nn.Embedding.reset_parameters(self) if hasattr(self, 'lora_A'): + self.trainable_ids[:] = -1 self.lora_A = nn.ParameterList() nn.init.zeros_(self.lora_B.weight) @@ -122,8 +123,7 @@ class LoraEmbedding(nn.Embedding, LoraLayer): if self.merge_weights and self.merged: if self.r > 0: weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0])) - if weights is not None: - self.weight[mask].data -= weights + self.weight[mask].data -= weights self.merged = False def eval(self): @@ -131,16 +131,14 @@ class LoraEmbedding(nn.Embedding, LoraLayer): if self.merge_weights and not self.merged: if self.r > 0: weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0])) - if weights is not None: - self.weight[mask].data += weights + self.weight[mask].data += weights self.merged = True - def forward(self, input_ids: torch.Tensor): + def forward(self, input_ids: torch.LongTensor): result = nn.Embedding.forward(self, input_ids) if self.r > 0 and not self.merged: weights, mask = self.get_weights(input_ids) - if weights is not None: - result[mask] += weights + result[mask] += weights return result -- cgit v1.2.3-70-g09d2