From 3924055ed24da9b6995303cd36282eb558ba0bf0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 16 Apr 2023 14:45:37 +0200 Subject: Fix --- models/clip/embeddings.py | 41 ++++------------- models/lora.py | 77 ++++++++++++++++---------------- models/sparse.py | 110 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 157 insertions(+), 71 deletions(-) create mode 100644 models/sparse.py (limited to 'models') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index d02ccc3..8aaea8f 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -10,23 +10,21 @@ from transformers import CLIPTextModel from transformers.models.clip import CLIPTextConfig from transformers.models.clip.modeling_clip import CLIPTextEmbeddings -from models.lora import LoraEmbedding +from models.sparse import SparseEmbedding class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): - def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0.0): + def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: int = 8, dropout: float = 0.0): super().__init__(config) self.position_embedding = embeddings.position_embedding self.initializer_factor = config.initializer_factor - self.token_embedding = LoraEmbedding( + self.token_embedding = SparseEmbedding( self.token_embedding.num_embeddings, self.token_embedding.embedding_dim, - r, - lora_alpha, - lora_dropout, + alpha, + dropout, ) - self.token_embedding.weight = embeddings.token_embedding.weight def resize(self, size: int): @@ -82,38 +80,17 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): return self.token_embedding(input_ids) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - ) -> torch.Tensor: - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] - - if position_ids is None: - position_ids = self.position_ids[:, :seq_length] - - if inputs_embeds is None: - inputs_embeds = self.get_embed(input_ids) - - position_embeddings = self.position_embedding(position_ids) - embeddings = inputs_embeds + position_embeddings - - return embeddings - def patch_managed_embeddings( text_encoder: CLIPTextModel, - r: int = 8, - lora_alpha: int = 8, - lora_dropout: float = 0.0 + alpha: int = 8, + dropout: float = 0.0 ) -> ManagedCLIPTextEmbeddings: text_embeddings = ManagedCLIPTextEmbeddings( text_encoder.config, text_encoder.text_model.embeddings, - r, - lora_alpha, - lora_dropout + alpha, + dropout ) text_encoder.text_model.embeddings = text_embeddings return text_embeddings diff --git a/models/lora.py b/models/lora.py index 01a540b..e506cff 100644 --- a/models/lora.py +++ b/models/lora.py @@ -1,8 +1,8 @@ from typing import Optional +import math import torch import torch.nn as nn -import torch.nn.functional as F class LoraLayer(): @@ -42,14 +42,12 @@ class LoraEmbedding(nn.Embedding, LoraLayer): self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights ) - self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long)) - self.trainable_ids -= 1 + self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1) - if r > 0: - self.lora_A = nn.ParameterList() - self.lora_B = nn.Linear(r, embedding_dim, bias=False) - self.scaling = self.lora_alpha / self.r - self.weight.requires_grad = False + self.lora_A = nn.ParameterList() + self.lora_B = nn.Linear(r, embedding_dim, bias=False) + self.scaling = self.lora_alpha / self.r + self.weight.requires_grad = False self.reset_parameters() @@ -70,8 +68,9 @@ class LoraEmbedding(nn.Embedding, LoraLayer): else: nn.init.zeros_(new_emb.weight.data) new_emb.weight.data[:n, :] = self.weight.data[:n, :] - new_emb.lora_A = self.lora_A - new_emb.lora_B = self.lora_B + for param in self.lora_A: + new_emb.lora_A.append(param) + new_emb.lora_B.weight[:].data = self.lora_B.weight[:].data new_emb.trainable_ids[:n] = self.trainable_ids[:n] return new_emb @@ -87,60 +86,60 @@ class LoraEmbedding(nn.Embedding, LoraLayer): n2 = n1 + new_ids.shape[0] self.trainable_ids[new_ids] = torch.arange(n1, n2) for _ in new_ids: - self.lora_A.append(nn.Parameter(self.weight.new_zeros(self.r))) + w = self.weight.new_zeros(self.r) + self.lora_A.append(w) + + if len(self.lora_A) > 1: + elems = torch.stack([param for param in self.lora_A]) + nn.init.kaiming_uniform_(elems, a=math.sqrt(5)) def get_weights(self, input_ids: torch.Tensor): if len(input_ids.shape) != 1: return torch.stack([self.get_weights(batch) for batch in input_ids]) - trainable_ids = self.trainable_ids[input_ids] - mask = ~(trainable_ids == -1) - trainable_ids = trainable_ids[mask] - weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim)) - elems = [self.lora_A[id] for id in trainable_ids] - if len(elems) != 0: - w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling - weights[mask] = w.to(dtype=weights.dtype) + if not self.merged: + trainable_ids = self.trainable_ids[input_ids] + mask = ~(trainable_ids == -1) + elems = [self.lora_A[id] for id in trainable_ids[mask]] + + if len(elems) != 0: + w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling + weights[mask] = w.to(dtype=weights.dtype) return weights def persist(self): - if self.r > 0: - weights = self.get_weights(torch.arange(self.trainable_ids.shape[0])) - self.weight.data += weights - self.trainable_ids[:] = -1 - self.lora_A = nn.ParameterList() + self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0])) + self.trainable_ids[:] = -1 + self.lora_A = nn.ParameterList() + nn.init.zeros_(self.lora_B.weight) def reset_parameters(self): nn.Embedding.reset_parameters(self) - if hasattr(self, 'lora_A'): + if hasattr(self, "lora_A"): self.trainable_ids[:] = -1 self.lora_A = nn.ParameterList() nn.init.zeros_(self.lora_B.weight) def train(self, mode: bool = True): nn.Embedding.train(self, mode) - if self.merge_weights and self.merged: - if self.r > 0: - weights = self.get_weights(torch.arange(self.trainable_ids.shape[0])) - self.weight.data -= weights + self.lora_A.train(mode) + self.lora_B.train(mode) + if not mode and self.merge_weights and not self.merged: + self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0])) + self.merged = True + elif self.merge_weights and self.merged: + self.weight.data -= self.get_weights(torch.arange(self.trainable_ids.shape[0])) self.merged = False def eval(self): nn.Embedding.eval(self) - if self.merge_weights and not self.merged: - if self.r > 0: - weights = self.get_weights(torch.arange(self.trainable_ids.shape[0])) - self.weight.data += weights - self.merged = True + self.lora_A.eval() + self.lora_B.eval() def forward(self, input_ids: torch.LongTensor): result = nn.Embedding.forward(self, input_ids) - - if self.r > 0 and not self.merged: - weights = self.get_weights(input_ids) - result += weights - + result += self.get_weights(input_ids) return result diff --git a/models/sparse.py b/models/sparse.py new file mode 100644 index 0000000..bd45696 --- /dev/null +++ b/models/sparse.py @@ -0,0 +1,110 @@ +from typing import Optional + +import torch +import torch.nn as nn + + +class SparseEmbedding(nn.Embedding): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + alpha: int = 1, + dropout: float = 0.0, + **kwargs + ): + nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) + + self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1) + + self.trainable = nn.ParameterList() + self.scaling = alpha + self.dropout_p = dropout + self.weight.requires_grad = False + + if dropout > 0.: + self.dropout = nn.Dropout(p=dropout) + else: + self.dropout = nn.Identity() + + self.reset_parameters() + + def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None): + n = min(self.num_embeddings, new_num_embeddings) + + new_emb = SparseEmbedding( + new_num_embeddings, + self.embedding_dim, + self.scaling, + self.dropout_p, + device=self.weight.device, + dtype=self.weight.dtype + ) + if initializer_factor is not None: + new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) + else: + nn.init.zeros_(new_emb.weight.data) + new_emb.weight.data[:n, :] = self.weight.data[:n, :] + for param in self.trainable: + new_emb.trainable.append(param) + new_emb.trainable_ids[:n] = self.trainable_ids[:n] + + return new_emb + + def mark_trainable(self, input_ids: torch.LongTensor): + trainable_ids = self.trainable_ids[input_ids] + new_ids = input_ids[trainable_ids == -1] + + if new_ids.shape[0] == 0: + return + + n1 = len(self.trainable) + n2 = n1 + new_ids.shape[0] + self.trainable_ids[new_ids] = torch.arange(n1, n2) + for _ in new_ids: + self.trainable.append(self.weight.new_zeros(self.embedding_dim)) + + def get_weights(self, input_ids: torch.Tensor): + original_shape = input_ids.shape + + if len(input_ids.shape) != 1: + input_ids = input_ids.view(input_ids.shape[0] * input_ids.shape[1]) + + weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim)) + + trainable_ids = self.trainable_ids[input_ids] + mask = ~(trainable_ids == -1) + elems = [self.trainable[id] for id in trainable_ids[mask]] + + if len(elems) != 0: + w = self.dropout(torch.stack(elems)) * self.scaling + weights[mask] = w.to(dtype=weights.dtype) + + if len(original_shape) != 1: + weights = weights.view(original_shape[0], original_shape[1], -1) + + return weights + + def persist(self): + self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0])) + self.trainable_ids[:] = -1 + self.trainable = nn.ParameterList() + + def reset_parameters(self): + nn.Embedding.reset_parameters(self) + if hasattr(self, "trainable"): + self.trainable_ids[:] = -1 + self.trainable = nn.ParameterList() + + def train(self, mode: bool = True): + nn.Embedding.train(self, mode) + self.trainable.train(mode) + + def eval(self): + nn.Embedding.eval(self) + self.trainable.eval() + + def forward(self, input_ids: torch.LongTensor): + result = nn.Embedding.forward(self, input_ids) + result += self.get_weights(input_ids) + return result -- cgit v1.2.3-70-g09d2