From 8364ce697ddf6117fdd4f7222832d546d63880de Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 21 Jun 2023 13:28:49 +0200
Subject: Update

---
 models/clip/util.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/models/clip/util.py b/models/clip/util.py
index f94fbc7..7196bb6 100644
--- a/models/clip/util.py
+++ b/models/clip/util.py
@@ -5,27 +5,32 @@ import torch
 from transformers import CLIPTokenizer, CLIPTextModel
 
 
-def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]], max_length: Optional[int] = None):
+def unify_input_ids(
+    tokenizer: CLIPTokenizer,
+    input_ids: list[list[int]],
+    max_length: Optional[int] = None,
+):
     if max_length is None:
         return tokenizer.pad(
             {"input_ids": input_ids},
             padding=True,
             pad_to_multiple_of=tokenizer.model_max_length,
-            return_tensors="pt"
+            return_tensors="pt",
         )
     else:
         return tokenizer.pad(
             {"input_ids": input_ids},
             padding="max_length",
             max_length=max_length,
-            return_tensors="pt"
+            return_tensors="pt",
         )
 
+
 def get_extended_embeddings(
     text_encoder: CLIPTextModel,
     input_ids: torch.LongTensor,
     position_ids: Optional[torch.LongTensor] = None,
-    attention_mask=None
+    attention_mask=None,
 ):
     model_max_length = text_encoder.config.max_position_embeddings
     prompts = input_ids.shape[0]
@@ -36,6 +41,8 @@ def get_extended_embeddings(
     if attention_mask is not None:
         attention_mask = attention_mask.view((-1, model_max_length))
 
-    text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
+    text_embeddings = text_encoder(
+        input_ids, position_ids=position_ids, attention_mask=attention_mask
+    )[0]
     text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
     return text_embeddings
--
cgit v1.2.3-54-g00ecf
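
A minimal usage sketch of the two helpers touched by this patch, showing how they compose to encode prompts longer than CLIP's 77-token window. The checkpoint name, prompts, and import path below are illustrative assumptions, not something this patch prescribes.

    # Sketch, assuming a stock CLIP checkpoint and that util.py is importable
    # as models.clip.util. unify_input_ids pads every prompt to a shared
    # multiple of tokenizer.model_max_length (77 for CLIP); the encoder then
    # processes each 77-token window separately.
    import torch
    from transformers import CLIPTokenizer, CLIPTextModel

    from models.clip.util import unify_input_ids, get_extended_embeddings

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # Tokenize without padding or truncation; the second prompt exceeds 77
    # tokens (the tokenizer warns about this, which is expected here).
    prompts = ["a photo of a cat", "a highly detailed portrait " * 24]
    input_ids = [tokenizer(p).input_ids for p in prompts]

    # max_length=None takes the padding=True branch: both sequences are padded
    # to the same multiple of 77 and returned as PyTorch tensors.
    batch = unify_input_ids(tokenizer, input_ids)

    with torch.no_grad():
        # Each prompt is reshaped into 77-token chunks, encoded chunk by
        # chunk, and the chunk embeddings are concatenated back per prompt.
        embeddings = get_extended_embeddings(
            text_encoder, batch.input_ids, attention_mask=batch.attention_mask
        )

    print(embeddings.shape)  # (2, padded_length, hidden_size)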