Diffstat (limited to 'models/clip/util.py')
 models/clip/util.py | 34 ++++++++++++++++++++++++++++++++++
 1 file changed, 34 insertions(+), 0 deletions(-)
diff --git a/models/clip/util.py b/models/clip/util.py
new file mode 100644
index 0000000..8de8c19
--- /dev/null
+++ b/models/clip/util.py
@@ -0,0 +1,34 @@
+from typing import Optional
+
+import torch
+
+from transformers import CLIPTokenizer, CLIPTextModel
+
+
+def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]):
+    return tokenizer.pad(
+        {"input_ids": input_ids},
+        padding=True,
+        pad_to_multiple_of=tokenizer.model_max_length,
+        return_tensors="pt"
+    )
+
+
+def get_extended_embeddings(
+    text_encoder: CLIPTextModel,
+    input_ids: torch.LongTensor,
+    position_ids: Optional[torch.LongTensor] = None,
+    attention_mask=None
+):
+    model_max_length = text_encoder.config.max_position_embeddings
+    prompts = input_ids.shape[0]
+
+    input_ids = input_ids.view((-1, model_max_length)).to(text_encoder.device)
+    if position_ids is not None:
+        position_ids = position_ids.view((-1, model_max_length)).to(text_encoder.device)
+    if attention_mask is not None:
+        attention_mask = attention_mask.view((-1, model_max_length)).to(text_encoder.device)
+
+    text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
+    text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
+    return text_embeddings
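
The commit adds two helpers: unify_input_ids pads a batch of token id lists to a common length that is a multiple of the tokenizer's model_max_length (77 for CLIP), and get_extended_embeddings splits each padded row into 77-token chunks, encodes every chunk independently, and re-concatenates the results, so prompts longer than CLIP's positional limit can still be embedded. Below is a minimal usage sketch; it is not part of the commit, and the checkpoint name, prompt strings, and variable names are illustrative assumptions.

# Minimal usage sketch (illustrative, not part of the commit).
# Assumes the standard Hugging Face CLIP checkpoint named below.
import torch
from transformers import CLIPTokenizer, CLIPTextModel

from models.clip.util import unify_input_ids, get_extended_embeddings

pretrained = "openai/clip-vit-large-patch14"  # assumed checkpoint
tokenizer = CLIPTokenizer.from_pretrained(pretrained)
text_encoder = CLIPTextModel.from_pretrained(pretrained)

# Tokenize without truncation so a prompt may exceed the 77-token window.
prompts = ["a photo of a cat", "a very long prompt " * 30]
input_ids = [tokenizer(p).input_ids for p in prompts]

# Pad all prompts to a shared length that is a multiple of
# tokenizer.model_max_length (77), so each row splits into whole chunks.
batch = unify_input_ids(tokenizer, input_ids)

with torch.no_grad():
    # Each 77-token chunk is encoded independently, then the chunks are
    # re-concatenated: shape (num_prompts, n_chunks * 77, hidden_size).
    embeddings = get_extended_embeddings(
        text_encoder, batch.input_ids, attention_mask=batch.attention_mask
    )

print(embeddings.shape)  # e.g. torch.Size([2, 154, 768])

Note the trade-off this chunking makes: each chunk receives its own position ids within max_position_embeddings, so the concatenated embedding covers the whole prompt, but attention never crosses chunk boundaries.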