From 8364ce697ddf6117fdd4f7222832d546d63880de Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 21 Jun 2023 13:28:49 +0200
Subject: Update

---
 models/clip/embeddings.py | 29 +++++++++++++++++------------
 models/clip/tokenizer.py  | 23 ++++++++++++-----------
 models/clip/util.py       | 17 ++++++++++++-----
 3 files changed, 41 insertions(+), 28 deletions(-)

(limited to 'models/clip')

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 7c7f2ac..8c3c6d4 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -14,7 +14,13 @@ from models.sparse import SparseEmbedding
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: int = 8, dropout: float = 0.0):
+    def __init__(
+        self,
+        config: CLIPTextConfig,
+        embeddings: CLIPTextEmbeddings,
+        alpha: int = 8,
+        dropout: float = 0.0,
+    ):
         super().__init__(config)
 
         self.position_embedding = embeddings.position_embedding
@@ -28,7 +34,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.token_embedding.weight = embeddings.token_embedding.weight
 
     def resize(self, size: int):
-        self.token_embedding = self.token_embedding.new_resized(size, self.initializer_factor)
+        self.token_embedding = self.token_embedding.new_resized(
+            size, self.initializer_factor
+        )
 
     def add_embed(
         self,
@@ -46,7 +54,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             initializer = [initializer]
 
         if isinstance(initializer, list):
-            initializer = (initializer * len(token_ids))[:len(token_ids)]
+            initializer = (initializer * len(token_ids))[: len(token_ids)]
 
         with torch.no_grad():
             initializer = self.get_embed(initializer)
@@ -76,24 +84,21 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
-            input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
+            input_ids = torch.tensor(
+                input_ids, device=self.token_embedding.weight.device, dtype=torch.long
+            )
 
         return self.token_embedding(input_ids)
 
 
 def patch_managed_embeddings(
-    text_encoder: CLIPTextModel,
-    alpha: int = 8,
-    dropout: float = 0.0
+    text_encoder: CLIPTextModel, alpha: int = 8, dropout: float = 0.0
 ) -> ManagedCLIPTextEmbeddings:
     if isinstance(text_encoder.text_model.embeddings, ManagedCLIPTextEmbeddings):
         return text_encoder.text_model.embeddings
-    
+
     text_embeddings = ManagedCLIPTextEmbeddings(
-        text_encoder.config,
-        text_encoder.text_model.embeddings,
-        alpha,
-        dropout
+        text_encoder.config, text_encoder.text_model.embeddings, alpha, dropout
     )
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 789b525..a866641 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -91,18 +91,21 @@ class MultiCLIPTokenizer(CLIPTokenizer):
             self.vector_shuffle = shuffle_none
 
     def add_multi_tokens(
-        self,
-        new_tokens: Union[str, list[str]],
-        num_vectors: Union[int, list[int]] = 1
+        self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1
     ) -> Union[list[int], list[list[int]]]:
         if isinstance(new_tokens, list):
             if isinstance(num_vectors, int):
                 num_vectors = [num_vectors] * len(new_tokens)
 
             if len(num_vectors) != len(new_tokens):
-                raise ValueError("Expected new_tokens and num_vectors to have the same len")
+                raise ValueError(
+                    "Expected new_tokens and num_vectors to have the same len"
+                )
 
-            return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)]
+            return [
+                self.add_multi_tokens(new_token, vecs)
+                for new_token, vecs in zip(new_tokens, num_vectors)
+            ]
 
         if isinstance(num_vectors, list):
             raise ValueError("Expected num_vectors to be int for single token")
@@ -129,13 +132,11 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         return [id]
 
     def expand_ids(self, ids: list[int]):
-        return [
-            new_id
-            for id in ids
-            for new_id in self.expand_id(id)
-        ]
+        return [new_id for id in ids for new_id in self.expand_id(id)]
 
-    def expand_batched_ids(self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]):
+    def expand_batched_ids(
+        self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]
+    ):
         if isinstance(input_ids, (list, tuple)) and isinstance(input_ids[0], list):
             return [self.expand_ids(batch) for batch in input_ids]
         else:
diff --git a/models/clip/util.py b/models/clip/util.py
index f94fbc7..7196bb6 100644
--- a/models/clip/util.py
+++ b/models/clip/util.py
@@ -5,27 +5,32 @@ import torch
 from transformers import CLIPTokenizer, CLIPTextModel
 
 
-def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]], max_length: Optional[int] = None):
+def unify_input_ids(
+    tokenizer: CLIPTokenizer,
+    input_ids: list[list[int]],
+    max_length: Optional[int] = None,
+):
     if max_length is None:
         return tokenizer.pad(
             {"input_ids": input_ids},
             padding=True,
             pad_to_multiple_of=tokenizer.model_max_length,
-            return_tensors="pt"
+            return_tensors="pt",
         )
     else:
         return tokenizer.pad(
             {"input_ids": input_ids},
             padding="max_length",
             max_length=max_length,
-            return_tensors="pt"
+            return_tensors="pt",
         )
+
 
 def get_extended_embeddings(
     text_encoder: CLIPTextModel,
     input_ids: torch.LongTensor,
     position_ids: Optional[torch.LongTensor] = None,
-    attention_mask=None
+    attention_mask=None,
 ):
     model_max_length = text_encoder.config.max_position_embeddings
     prompts = input_ids.shape[0]
@@ -36,6 +41,8 @@ def get_extended_embeddings(
     if attention_mask is not None:
         attention_mask = attention_mask.view((-1, model_max_length))
 
-    text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
+    text_embeddings = text_encoder(
+        input_ids, position_ids=position_ids, attention_mask=attention_mask
+    )[0]
     text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
     return text_embeddings
-- 
cgit v1.2.3-70-g09d2