1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
|
from typing import List, Union
import torch
from transformers import CLIPTokenizer, CLIPTextModel
class PromptProcessor:
    """Tokenize prompts and encode them with a text encoder in fixed-size chunks.

    Token sequences are padded to a multiple of ``tokenizer.model_max_length``
    and then split into rows of exactly that length before being fed to the
    text encoder, so a prompt may span several encoder passes.
    """

    def __init__(self, tokenizer: "CLIPTokenizer", text_encoder: "CLIPTextModel"):
        """Store the tokenizer and text encoder used by the other methods."""
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder

    def get_input_ids(self, prompt: Union[str, List[str]]):
        """Tokenize ``prompt`` with ``padding="do_not_pad"``.

        Args:
            prompt: A single prompt or a list of prompts.

        Returns:
            The ``input_ids`` attribute of the tokenizer output.
            NOTE(review): presumably a flat list of ids for a ``str`` prompt
            and a list of lists for a list of prompts -- confirm against the
            tokenizer in use.
        """
        encoded = self.tokenizer(prompt, padding="do_not_pad")
        return encoded.input_ids

    def unify_input_ids(self, input_ids: Union[List[int], List[List[int]]]):
        """Pad token ids through ``tokenizer.pad`` and request PyTorch tensors.

        The call asks for ``padding=True`` with ``pad_to_multiple_of`` set to
        ``tokenizer.model_max_length``, so the padded length can later be split
        into whole ``model_max_length`` chunks by ``get_embeddings``.

        Args:
            input_ids: Token ids as produced by ``get_input_ids``.

        Returns:
            Whatever ``tokenizer.pad`` returns for ``return_tensors="pt"``.
            NOTE(review): presumably a mapping with ``input_ids`` and
            ``attention_mask`` tensors -- confirm against the tokenizer.
        """
        return self.tokenizer.pad(
            {"input_ids": input_ids},
            padding=True,
            pad_to_multiple_of=self.tokenizer.model_max_length,
            return_tensors="pt",
        )

    def get_embeddings(
        self,
        input_ids: torch.Tensor,
        attention_mask: Union[torch.Tensor, None] = None,
    ) -> torch.Tensor:
        """Encode token ids chunk by chunk and regroup the result per prompt.

        Args:
            input_ids: Token-id tensor whose first dimension indexes prompts.
                Its total element count must be divisible by
                ``tokenizer.model_max_length``.
                NOTE(review): assumes shape ``(prompts, k * model_max_length)``
                -- confirm with callers.
            attention_mask: Optional mask reshaped the same way as
                ``input_ids`` and forwarded to the text encoder.

        Returns:
            Tensor of shape ``(prompts, -1, hidden)`` where ``prompts`` is
            ``input_ids.shape[0]`` and ``hidden`` is the last dimension of the
            encoder's first output.
        """
        prompts = input_ids.shape[0]
        chunk_len = self.tokenizer.model_max_length

        # reshape (not view) so non-contiguous tensors, e.g. transposed or
        # sliced ones, are accepted; for contiguous input it is still a view.
        input_ids = input_ids.reshape((-1, chunk_len)).to(self.text_encoder.device)
        if attention_mask is not None:
            attention_mask = attention_mask.reshape((-1, chunk_len)).to(
                self.text_encoder.device
            )

        # The first element of the encoder output holds the per-token states.
        text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0]

        # Stitch the chunks of each prompt back together along the token axis.
        return text_embeddings.reshape((prompts, -1, text_embeddings.shape[2]))
|