from typing import List, Optional, Union

import torch
from transformers import CLIPTokenizer, CLIPTextModel


class PromptProcessor:
    """Tokenize prompts and encode them with a CLIP text encoder.

    Prompts are tokenized without padding or truncation, padded to a
    multiple of ``tokenizer.model_max_length``, and then encoded in
    windows of ``model_max_length`` tokens, so a prompt may be longer
    than a single encoder window.
    """

    def __init__(self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel):
        """Store the tokenizer and text encoder used by the other methods."""
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder

    def get_input_ids(
        self, prompt: Union[str, List[str]]
    ) -> Union[List[int], List[List[int]]]:
        """Tokenize ``prompt`` without padding.

        Args:
            prompt: A single prompt or a list of prompts.

        Returns:
            The tokenizer's ``input_ids``. No padding is requested, so for
            a list of prompts the per-prompt id lists may differ in length.
        """
        return self.tokenizer(
            prompt,
            padding="do_not_pad",
        ).input_ids

    def unify_input_ids(self, input_ids: List[List[int]]) -> torch.Tensor:
        """Pad a batch of token-id lists to one common length.

        The common length is rounded up to a multiple of
        ``tokenizer.model_max_length`` so that ``get_embeddings`` can split
        every row into whole encoder windows.

        Args:
            input_ids: One list of token ids per prompt.

        Returns:
            The padded ids as a PyTorch tensor.
        """
        return self.tokenizer.pad(
            {"input_ids": input_ids},
            padding=True,
            pad_to_multiple_of=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

    def get_embeddings(self, input_ids: torch.IntTensor) -> torch.Tensor:
        """Encode padded token ids into text embeddings.

        Args:
            input_ids: Tensor whose first dimension indexes prompts and
                whose total element count is a multiple of
                ``tokenizer.model_max_length``.

        Returns:
            Tensor of shape ``(prompts, tokens, hidden)`` where ``hidden``
            is the width of the text encoder's output.
        """
        prompts = input_ids.shape[0]
        # Split every prompt into encoder-sized windows; each window is
        # encoded as its own batch row.
        input_ids = input_ids.reshape(
            (-1, self.tokenizer.model_max_length)
        ).to(self.text_encoder.device)
        hidden_states = self.text_encoder(input_ids)[0]
        # Merge the windows of each prompt back along the token axis. The
        # hidden width is taken from the encoder output rather than being
        # hard-coded, so encoders that are not 768-wide also work.
        text_embeddings = hidden_states.reshape(
            (prompts, -1, hidden_states.shape[-1])
        )
        return text_embeddings