Diffstat (limited to 'models')
-rw-r--r--   models/clip/prompt.py   31
1 file changed, 31 insertions, 0 deletions
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
new file mode 100644
index 0000000..c1e3340
--- /dev/null
+++ b/models/clip/prompt.py
@@ -0,0 +1,31 @@
+from typing import List, Optional, Union
+
+import torch
+
+from transformers import CLIPTokenizer, CLIPTextModel
+
+
+class PromptProcessor():
+    def __init__(self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel):
+        self.tokenizer = tokenizer
+        self.text_encoder = text_encoder
+
+    def get_input_ids(self, prompt: Union[str, List[str]]):
+        return self.tokenizer(
+            prompt,
+            padding="do_not_pad",
+        ).input_ids
+
+    def unify_input_ids(self, input_ids: List[int]):
+        return self.tokenizer.pad(
+            {"input_ids": input_ids},
+            padding=True,
+            pad_to_multiple_of=self.tokenizer.model_max_length,
+            return_tensors="pt"
+        ).input_ids
+
+    def get_embeddings(self, input_ids: torch.IntTensor):
+        prompts = input_ids.shape[0]
+        input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
+        text_embeddings = self.text_encoder(input_ids)[0].reshape((prompts, -1, 768))
+        return text_embeddings
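A minimal usage sketch of the class added above (not part of the commit): it assumes the CLIP-L/14 text stack used by Stable Diffusion v1, i.e. the openai/clip-vit-large-patch14 tokenizer and text encoder, whose model_max_length is 77 and whose hidden size is 768 (matching the hard-coded 768 in get_embeddings). The checkpoint name and prompt strings are illustrative.

# Usage sketch; the checkpoint below is an assumption, not taken from the commit.
import torch
from transformers import CLIPTokenizer, CLIPTextModel

from models.clip.prompt import PromptProcessor

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
processor = PromptProcessor(tokenizer, text_encoder)

prompts = ["a photo of a cat", "a watercolor painting of a castle on a hill"]

# Tokenize without padding, so each prompt keeps its natural token length.
input_ids = processor.get_input_ids(prompts)

# Pad the batch to a common length that is a multiple of model_max_length (77),
# so every prompt splits cleanly into encoder-sized chunks.
batch = processor.unify_input_ids(input_ids)

# Encode each 77-token chunk and re-join the chunks per prompt, yielding a
# (num_prompts, num_chunks * 77, 768) embedding tensor.
with torch.no_grad():
    embeddings = processor.get_embeddings(batch)

print(batch.shape)       # e.g. torch.Size([2, 77])
print(embeddings.shape)  # e.g. torch.Size([2, 77, 768])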