From 99b4dba56e3e1e434820d1221d561e90f1a6d30a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 15 Apr 2023 13:11:11 +0200
Subject: TI via LoRA

---
 models/clip/embeddings.py | 76 ++++++++++++++++-------------------------------
 1 file changed, 26 insertions(+), 50 deletions(-)

(limited to 'models/clip/embeddings.py')

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9be8256..60c1b20 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -11,49 +11,27 @@ from transformers import CLIPTextModel
 from transformers.models.clip import CLIPTextConfig
 from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
-from models.sparse import PseudoSparseEmbedding
-
-
-def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding:
-    old_num_embeddings, old_embedding_dim = old_embedding.weight.shape
-
-    if old_num_embeddings == new_num_embeddings:
-        return old_embedding
-
-    n = min(old_num_embeddings, new_num_embeddings)
-
-    new_embedding = nn.Embedding(
-        new_num_embeddings,
-        old_embedding_dim,
-        device=old_embedding.weight.device,
-        dtype=old_embedding.weight.dtype
-    )
-    if initializer_factor is not None:
-        new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
-    else:
-        nn.init.zeros_(new_embedding.weight.data)
-    new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :]
-    return new_embedding
+from models.lora import LoraEmbedding
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, dropout_p: float = 0.0):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0.0):
         super().__init__(config)
 
-        self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
-
-        self.token_override_embedding = PseudoSparseEmbedding(
+        self.token_embedding = LoraEmbedding(
+            self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
-            dropout_p=dropout_p,
-            device=self.token_embedding.weight.device,
-            dtype=self.token_embedding.weight.dtype,
+            r,
+            lora_alpha,
+            lora_dropout,
         )
 
+        self.token_embedding.weight = embeddings.token_embedding.weight
+
     def resize(self, size: int):
-        self.token_override_embedding.resize(size)
-        self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
+        self.token_embedding = self.token_embedding.new_resized(size, self.initializer_factor)
 
     def add_embed(
         self,
@@ -87,7 +65,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.token_embedding.weight.data[token_ids] = initializer
-        self.token_override_embedding.set(token_ids, initializer)
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
@@ -97,26 +74,14 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        input_ids = torch.arange(
-            self.token_embedding.num_embeddings,
-            device=self.token_override_embedding.mapping.device
-        )
-        embs, mask = self.token_override_embedding(input_ids)
-        if embs is not None:
-            input_ids = input_ids[mask]
-            self.token_embedding.weight.data[input_ids] = embs
-            self.token_override_embedding.unset(input_ids)
+        self.token_embedding.eval()
+        self.token_embedding.merged = False
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
-        embs = self.token_embedding(input_ids)
-        embs_override, mask = self.token_override_embedding(input_ids)
-        if embs_override is not None:
-            embs[mask] = embs_override
-
-        return embs
+        return self.token_embedding(input_ids)
 
     def forward(
         self,
@@ -138,7 +103,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         return embeddings
 
 
-def patch_managed_embeddings(text_encoder: CLIPTextModel, dropout_p: float = 0.0) -> ManagedCLIPTextEmbeddings:
-    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, dropout_p)
+def patch_managed_embeddings(
+    text_encoder: CLIPTextModel,
+    r: int = 8,
+    lora_alpha: int = 8,
+    lora_dropout: float = 0.0
+) -> ManagedCLIPTextEmbeddings:
+    text_embeddings = ManagedCLIPTextEmbeddings(
+        text_encoder.config,
+        text_encoder.text_model.embeddings,
+        r,
+        lora_alpha,
+        lora_dropout
+    )
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
-- 
cgit v1.2.3-70-g09d2
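
A minimal usage sketch, not part of the patch: it assumes this file is importable as models.clip.embeddings, the checkpoint name and the "<concept>" placeholder token are illustrative, and the LoRA hyperparameters simply repeat the new defaults.

    from transformers import CLIPTextModel, CLIPTokenizer

    from models.clip.embeddings import patch_managed_embeddings

    # Illustrative checkpoint; any CLIP text encoder works the same way.
    pretrained = "openai/clip-vit-large-patch14"
    tokenizer = CLIPTokenizer.from_pretrained(pretrained)
    text_encoder = CLIPTextModel.from_pretrained(pretrained)

    # Swap the stock CLIPTextEmbeddings for the LoRA-backed ManagedCLIPTextEmbeddings.
    embeddings = patch_managed_embeddings(text_encoder, r=8, lora_alpha=8, lora_dropout=0.0)

    # Register a hypothetical placeholder token and grow the embedding table to match.
    tokenizer.add_tokens(["<concept>"])
    embeddings.resize(len(tokenizer))

    # ... train the LoRA factors, then make the learned embeddings permanent.
    embeddings.persist()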