Diffstat (limited to 'models/clip/prompt.py')
 -rw-r--r--  models/clip/prompt.py | 38 --------------------------------------
 1 file changed, 0 insertions(+), 38 deletions(-)
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
deleted file mode 100644
index a7380be..0000000
--- a/models/clip/prompt.py
+++ /dev/null
@@ -1,38 +0,0 @@
-from typing import Union, Optional
-
-import torch
-
-from transformers import CLIPTokenizer, CLIPTextModel
-
-
-class PromptProcessor():
-    def __init__(self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel):
-        self.tokenizer = tokenizer
-        self.text_encoder = text_encoder
-
-    def get_input_ids(self, prompt: Union[str, list[str]]):
-        return self.tokenizer(
-            prompt,
-            padding="do_not_pad",
-        ).input_ids
-
-    def unify_input_ids(self, input_ids: list[list[int]]):
-        return self.tokenizer.pad(
-            {"input_ids": input_ids},
-            padding=True,
-            pad_to_multiple_of=self.tokenizer.model_max_length,
-            return_tensors="pt"
-        )
-
-    def get_embeddings(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None, attention_mask=None):
-        prompts = input_ids.shape[0]
-
-        input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
-        if position_ids is not None:
-            position_ids = position_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
-        if attention_mask is not None:
-            attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
-
-        text_embeddings = self.text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
-        text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
-        return text_embeddings
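
For reference, a minimal usage sketch of the deleted PromptProcessor, not part of the commit itself. The checkpoint name "openai/clip-vit-large-patch14" is an assumption for illustration; any matching CLIPTokenizer/CLIPTextModel pair from transformers should work the same way:

    # Hypothetical usage example; the checkpoint choice is illustrative only.
    import torch
    from transformers import CLIPTokenizer, CLIPTextModel

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    processor = PromptProcessor(tokenizer, text_encoder)

    # Tokenize without padding, then pad the batch to a multiple of
    # model_max_length (77 for CLIP) so get_embeddings can reshape each
    # prompt into fixed-size chunks before calling the text encoder.
    ids = processor.get_input_ids(["a photo of a cat", "a painting of a dog"])
    batch = processor.unify_input_ids(ids)

    with torch.no_grad():
        emb = processor.get_embeddings(batch.input_ids, attention_mask=batch.attention_mask)

    # emb has shape (num_prompts, num_chunks * 77, hidden_size)
    print(emb.shape)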