from typing import Optional

import torch
from transformers import CLIPTokenizer, CLIPTextModel


def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]):
    # Pad a batch of variable-length token sequences to a shared length that is
    # a multiple of the tokenizer's model_max_length (77 for CLIP), so the
    # padded batch can later be split evenly into encoder-sized chunks.
    return tokenizer.pad(
        {"input_ids": input_ids},
        padding=True,
        pad_to_multiple_of=tokenizer.model_max_length,
        return_tensors="pt",
    )


def get_extended_embeddings(
    text_encoder: CLIPTextModel,
    input_ids: torch.LongTensor,
    position_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
):
    # Encode prompts longer than the CLIP context window by splitting each
    # prompt into chunks of max_position_embeddings tokens, encoding all
    # chunks as one flat batch, then concatenating the chunk embeddings
    # back together per prompt.
    model_max_length = text_encoder.config.max_position_embeddings
    prompts = input_ids.shape[0]

    # Reshape (prompts, n_chunks * max_len) -> (prompts * n_chunks, max_len).
    input_ids = input_ids.view((-1, model_max_length))
    if position_ids is not None:
        position_ids = position_ids.view((-1, model_max_length))
    if attention_mask is not None:
        attention_mask = attention_mask.view((-1, model_max_length))

    # [0] is the last hidden state: (prompts * n_chunks, max_len, hidden_dim).
    text_embeddings = text_encoder(
        input_ids, position_ids=position_ids, attention_mask=attention_mask
    )[0]

    # Stitch the chunks back: (prompts, n_chunks * max_len, hidden_dim).
    text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))

    return text_embeddings
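

if __name__ == "__main__":
    # Minimal usage sketch, assuming the standard OpenAI CLIP checkpoint
    # ("openai/clip-vit-large-patch14"); any CLIP tokenizer/encoder pair with
    # matching vocabularies should work the same way.
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # Tokenize without truncation so prompts may exceed the 77-token window.
    prompts = ["a photo of a cat", "a very detailed long prompt " * 20]
    input_ids = tokenizer(prompts, truncation=False).input_ids

    # Pad to a multiple of 77, then encode chunk-wise.
    batch = unify_input_ids(tokenizer, input_ids)
    with torch.no_grad():
        embeddings = get_extended_embeddings(
            text_encoder, batch.input_ids, attention_mask=batch.attention_mask
        )

    # One row per prompt; sequence length is a multiple of 77.
    print(embeddings.shape)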