from typing import Optional

import torch
from transformers import CLIPTokenizer, CLIPTextModel


def unify_input_ids(
    tokenizer: CLIPTokenizer,
    input_ids: list[list[int]],
    max_length: Optional[int] = None,
):
    # Pad a batch of variable-length token sequences to a common length.
    if max_length is None:
        # Pad to the longest sequence in the batch, rounded up to a multiple
        # of the tokenizer's window size (77 for CLIP), so the result can be
        # split evenly into encoder-sized chunks.
        return tokenizer.pad(
            {"input_ids": input_ids},
            padding=True,
            pad_to_multiple_of=tokenizer.model_max_length,
            return_tensors="pt",
        )
    else:
        # Pad to an explicit target length instead.
        return tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )


def get_extended_embeddings(
    text_encoder: CLIPTextModel,
    input_ids: torch.LongTensor,
    position_ids: Optional[torch.LongTensor] = None,
    attention_mask=None,
):
    # Encode prompts that exceed the text encoder's positional window by
    # splitting each prompt into window-sized chunks, encoding all chunks as
    # one large batch, and re-joining the chunk embeddings per prompt.
    model_max_length = text_encoder.config.max_position_embeddings
    prompts = input_ids.shape[0]

    # (batch, n_chunks * window) -> (batch * n_chunks, window)
    input_ids = input_ids.view((-1, model_max_length))
    if position_ids is not None:
        position_ids = position_ids.view((-1, model_max_length))
    if attention_mask is not None:
        attention_mask = attention_mask.view((-1, model_max_length))

    # [0] selects the encoder's last hidden state.
    text_embeddings = text_encoder(
        input_ids, position_ids=position_ids, attention_mask=attention_mask
    )[0]

    # (batch * n_chunks, window, dim) -> (batch, n_chunks * window, dim)
    text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
    return text_embeddings
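

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): shows how the two
# helpers combine to embed prompts longer than CLIP's 77-token window. The
# checkpoint name and prompts below are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    prompts = [
        "a photo of a cat",
        "a highly detailed painting of a castle on a hill at sunset " * 16,
    ]
    # Tokenize without truncation so long prompts keep all their tokens;
    # transformers will warn about sequences over model_max_length.
    ids = tokenizer(prompts, truncation=False).input_ids

    # Pad the batch to a shared length that is a multiple of 77.
    batch = unify_input_ids(tokenizer, ids)

    # Encode each 77-token chunk separately, then re-join per prompt.
    embeddings = get_extended_embeddings(
        text_encoder, batch.input_ids, attention_mask=batch.attention_mask
    )
    print(embeddings.shape)  # (2, n_chunks * 77, 768) for ViT-L/14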