From 99b4dba56e3e1e434820d1221d561e90f1a6d30a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 15 Apr 2023 13:11:11 +0200 Subject: TI via LoRA --- models/clip/embeddings.py | 76 +++++++++------------------ models/lora.py | 131 ++++++++++++++++++++++++++++++++++++++++++++++ models/sparse.py | 66 ----------------------- 3 files changed, 157 insertions(+), 116 deletions(-) create mode 100644 models/lora.py delete mode 100644 models/sparse.py (limited to 'models') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 9be8256..60c1b20 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -11,49 +11,27 @@ from transformers import CLIPTextModel from transformers.models.clip import CLIPTextConfig from transformers.models.clip.modeling_clip import CLIPTextEmbeddings -from models.sparse import PseudoSparseEmbedding - - -def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding: - old_num_embeddings, old_embedding_dim = old_embedding.weight.shape - - if old_num_embeddings == new_num_embeddings: - return old_embedding - - n = min(old_num_embeddings, new_num_embeddings) - - new_embedding = nn.Embedding( - new_num_embeddings, - old_embedding_dim, - device=old_embedding.weight.device, - dtype=old_embedding.weight.dtype - ) - if initializer_factor is not None: - new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) - else: - nn.init.zeros_(new_embedding.weight.data) - new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :] - return new_embedding +from models.lora import LoraEmbedding class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): - def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, dropout_p: float = 0.0): + def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0.0): super().__init__(config) - self.token_embedding = embeddings.token_embedding self.position_embedding = embeddings.position_embedding self.initializer_factor = config.initializer_factor - - self.token_override_embedding = PseudoSparseEmbedding( + self.token_embedding = LoraEmbedding( + self.token_embedding.num_embeddings, self.token_embedding.embedding_dim, - dropout_p=dropout_p, - device=self.token_embedding.weight.device, - dtype=self.token_embedding.weight.dtype, + r, + lora_alpha, + lora_dropout, ) + self.token_embedding.weight = embeddings.token_embedding.weight + def resize(self, size: int): - self.token_override_embedding.resize(size) - self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) + self.token_embedding = self.token_embedding.new_resized(size, self.initializer_factor) def add_embed( self, @@ -87,7 +65,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): token_ids = torch.tensor(token_ids, dtype=torch.long) self.token_embedding.weight.data[token_ids] = initializer - self.token_override_embedding.set(token_ids, initializer) def load_embed(self, input_ids: list[int], filename: Path): with safe_open(filename, framework="pt", device="cpu") as file: @@ -97,26 +74,14 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): save_file({"embed": self.get_embed(input_ids)}, filename) def persist(self): - input_ids = torch.arange( - self.token_embedding.num_embeddings, - device=self.token_override_embedding.mapping.device - ) - embs, mask = self.token_override_embedding(input_ids) - if embs is not None: - input_ids = input_ids[mask] - self.token_embedding.weight.data[input_ids] = embs - self.token_override_embedding.unset(input_ids) + self.token_embedding.eval() + self.token_embedding.merged = False def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): if isinstance(input_ids, list): input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) - embs = self.token_embedding(input_ids) - embs_override, mask = self.token_override_embedding(input_ids) - if embs_override is not None: - embs[mask] = embs_override - - return embs + return self.token_embedding(input_ids) def forward( self, @@ -138,7 +103,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): return embeddings -def patch_managed_embeddings(text_encoder: CLIPTextModel, dropout_p: float = 0.0) -> ManagedCLIPTextEmbeddings: - text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, dropout_p) +def patch_managed_embeddings( + text_encoder: CLIPTextModel, + r: int = 8, + lora_alpha: int = 8, + lora_dropout: float = 0.0 +) -> ManagedCLIPTextEmbeddings: + text_embeddings = ManagedCLIPTextEmbeddings( + text_encoder.config, + text_encoder.text_model.embeddings, + r, + lora_alpha, + lora_dropout + ) text_encoder.text_model.embeddings = text_embeddings return text_embeddings diff --git a/models/lora.py b/models/lora.py new file mode 100644 index 0000000..c0f74a6 --- /dev/null +++ b/models/lora.py @@ -0,0 +1,131 @@ +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LoraLayer(): + def __init__( + self, + r: int, + lora_alpha: int, + lora_dropout: float, + merge_weights: bool, + ): + self.r = r + self.lora_alpha = lora_alpha + self.lora_dropout_p = lora_dropout + + if lora_dropout > 0.: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = nn.Identity() + + self.merged = False + self.merge_weights = merge_weights + + +class LoraEmbedding(nn.Embedding, LoraLayer): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + merge_weights: bool = True, + **kwargs + ): + nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) + LoraLayer.__init__( + self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights + ) + + self.register_buffer('trainable_ids', torch.zeros(num_embeddings, device=self.weight.device, dtype=torch.long)) + self.trainable_ids -= 1 + + if r > 0: + self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0))) + self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r))) + self.scaling = self.lora_alpha / self.r + self.weight.requires_grad = False + + self.reset_parameters() + + def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None): + n = min(self.num_embeddings, new_num_embeddings) + + new_emb = LoraEmbedding( + new_num_embeddings, + self.embedding_dim, + self.r, + self.lora_alpha, + self.lora_dropout_p, + device=self.weight.device, + dtype=self.weight.dtype + ) + if initializer_factor is not None: + new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) + else: + nn.init.zeros_(new_emb.weight.data) + new_emb.weight.data[:n, :] = self.weight.data[:n, :] + new_emb.lora_A = self.lora_A + new_emb.lora_B = self.lora_B + new_emb.trainable_ids[:n] = self.trainable_ids[:n] + + return new_emb + + def mark_trainable(self, input_ids): + trainable_ids = self.trainable_ids[input_ids] + new_ids = trainable_ids[trainable_ids == -1] + + if new_ids.shape[0] == 0: + return + + n = self.trainable_ids.shape[0] + self.trainable_ids[new_ids] = torch.arange(n, n + new_ids.shape[0]) + + lora_A = nn.Parameter(self.weight.new_zeros((self.trainable_ids.shape[0], 0))) + lora_A.data[:n] = self.lora_A.data + self.lora_A = lora_A + + def reset_parameters(self): + nn.Embedding.reset_parameters(self) + if hasattr(self, 'lora_A'): + nn.init.zeros_(self.lora_A) + nn.init.normal_(self.lora_B) + + def train(self, mode: bool = True): + nn.Embedding.train(self, mode) + if self.merge_weights and self.merged: + if self.r > 0: + mask = ~(self.trainable_ids == -1) + trainable_ids = self.trainable_ids[mask] + self.weight[trainable_ids].data -= (self.lora_B @ self.lora_A).T * self.scaling + self.merged = False + + def eval(self): + nn.Embedding.eval(self) + if self.merge_weights and not self.merged: + if self.r > 0: + mask = ~(self.trainable_ids == -1) + trainable_ids = self.trainable_ids[mask] + self.weight[trainable_ids].data += (self.lora_B @ self.lora_A) * self.scaling + self.merged = True + + def forward(self, input_ids: torch.Tensor): + result = nn.Embedding.forward(self, input_ids) + + if self.r > 0 and not self.merged: + trainable_ids = self.trainable_ids[input_ids] + mask = ~(trainable_ids == -1) + trainable_ids = trainable_ids[mask] + + after_A = F.embedding( + trainable_ids, self.lora_A.T, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse + ) + result[mask] += (after_A @ self.lora_B.T) * self.scaling + + return result diff --git a/models/sparse.py b/models/sparse.py deleted file mode 100644 index 07b3413..0000000 --- a/models/sparse.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import Optional - -import torch -import torch.nn as nn - - -class PseudoSparseEmbedding(nn.Module): - def __init__(self, embedding_dim: int, dropout_p: float = 0.0, device=None, dtype=torch.float32): - super().__init__() - - self.embedding_dim = embedding_dim - self.dtype = dtype - self.params = nn.ParameterList() - - if dropout_p > 0.0: - self.dropout = nn.Dropout(p=dropout_p) - else: - self.dropout = nn.Identity() - - self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long)) - - def forward(self, input_ids: torch.LongTensor): - input_ids = input_ids.to(self.mapping.device) - ids = self.mapping[input_ids] - mask = ~(ids == -1) - - if torch.all(~mask): - embs = None - else: - embs = self.dropout(torch.stack([self.params[id] for id in ids[mask]])) - - return embs, mask - - def resize(self, new_num_embeddings: int): - old_num_embeddings = self.mapping.shape[0] - n = min(old_num_embeddings, new_num_embeddings) - - new_mapping = torch.zeros(new_num_embeddings, device=self.mapping.device, dtype=torch.long) - 1 - new_mapping[:n] = self.mapping[:n] - - self.mapping = new_mapping - - def set(self, input_ids: torch.LongTensor, tensor: Optional[torch.Tensor] = None): - if len(input_ids.shape) != 0: - if tensor is not None: - return [self.set(id, t) for id, t in zip(input_ids, tensor)] - else: - return [self.set(id) for id in input_ids] - - if tensor is None: - tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype) - - if tensor.shape[-1] != self.embedding_dim: - raise ValueError(f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]") - - id = self.mapping[input_ids] - - if id == -1: - id = len(self.params) - self.mapping[input_ids] = id - self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)) - - self.params[id] = tensor - - def unset(self, input_ids: torch.LongTensor): - self.mapping[input_ids] = -1 -- cgit v1.2.3-70-g09d2