From 3924055ed24da9b6995303cd36282eb558ba0bf0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 16 Apr 2023 14:45:37 +0200 Subject: Fix --- models/clip/embeddings.py | 41 +++++++++-------------------------------- 1 file changed, 9 insertions(+), 32 deletions(-) (limited to 'models/clip') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index d02ccc3..8aaea8f 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -10,23 +10,21 @@ from transformers import CLIPTextModel from transformers.models.clip import CLIPTextConfig from transformers.models.clip.modeling_clip import CLIPTextEmbeddings -from models.lora import LoraEmbedding +from models.sparse import SparseEmbedding class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): - def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0.0): + def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: int = 8, dropout: float = 0.0): super().__init__(config) self.position_embedding = embeddings.position_embedding self.initializer_factor = config.initializer_factor - self.token_embedding = LoraEmbedding( + self.token_embedding = SparseEmbedding( self.token_embedding.num_embeddings, self.token_embedding.embedding_dim, - r, - lora_alpha, - lora_dropout, + alpha, + dropout, ) - self.token_embedding.weight = embeddings.token_embedding.weight def resize(self, size: int): @@ -82,38 +80,17 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): return self.token_embedding(input_ids) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - ) -> torch.Tensor: - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] - - if position_ids is None: - position_ids = self.position_ids[:, :seq_length] - - if inputs_embeds is None: - inputs_embeds = self.get_embed(input_ids) - - position_embeddings = self.position_embedding(position_ids) - embeddings = inputs_embeds + position_embeddings - - return embeddings - def patch_managed_embeddings( text_encoder: CLIPTextModel, - r: int = 8, - lora_alpha: int = 8, - lora_dropout: float = 0.0 + alpha: int = 8, + dropout: float = 0.0 ) -> ManagedCLIPTextEmbeddings: text_embeddings = ManagedCLIPTextEmbeddings( text_encoder.config, text_encoder.text_model.embeddings, - r, - lora_alpha, - lora_dropout + alpha, + dropout ) text_encoder.text_model.embeddings = text_embeddings return text_embeddings -- cgit v1.2.3-54-g00ecf