from typing import Union

import torch
from transformers import CLIPTokenizer, CLIPTextModel


class PromptProcessor:
    """Tokenize text prompts and encode them into CLIP text embeddings.

    Supports prompts longer than the text encoder's context window: padded
    token sequences are split into ``model_max_length``-sized chunks, each
    chunk is encoded independently, and the per-chunk embeddings are
    re-joined along the sequence dimension.
    """

    def __init__(self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder

    def get_input_ids(self, prompt: Union[str, list[str]]):
        """Tokenize ``prompt`` without padding.

        Returns the raw ``input_ids``: a list of ints for a single string,
        or a list of per-prompt lists for a batch. Padding is deferred to
        :meth:`unify_input_ids` so variable-length prompts can be collected
        first.
        """
        return self.tokenizer(
            prompt,
            padding="do_not_pad",
        ).input_ids

    def unify_input_ids(self, input_ids: list[list[int]]):
        """Pad a batch of token-id sequences to a common length.

        Pads to a multiple of ``model_max_length`` so every sequence can
        later be split into whole encoder-sized chunks by
        :meth:`get_embeddings`. Returns a ``BatchEncoding`` with PyTorch
        tensors (``input_ids`` and ``attention_mask``).
        """
        return self.tokenizer.pad(
            {"input_ids": input_ids},
            padding=True,
            pad_to_multiple_of=self.tokenizer.model_max_length,
            return_tensors="pt",
        )

    def get_embeddings(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
        """Encode padded token ids into text embeddings.

        ``input_ids`` has shape ``(batch, k * model_max_length)``; its width
        must be an exact multiple of ``model_max_length`` (guaranteed when it
        comes from :meth:`unify_input_ids` — the ``view`` below raises
        otherwise). Each prompt is split into ``k`` encoder-sized chunks,
        all chunks are encoded in one forward pass, and the results are
        stitched back to ``(batch, k * model_max_length, hidden_dim)``.
        """
        prompts = input_ids.shape[0]
        # (batch, k*L) -> (batch*k, L): flatten chunks into the batch dim.
        input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
        if attention_mask is not None:
            attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
        # [0] is the encoder's last_hidden_state.
        text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0]
        # (batch*k, L, dim) -> (batch, k*L, dim): rejoin chunks per prompt.
        text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
        return text_embeddings