diff options
Diffstat (limited to 'models/clip/util.py')
-rw-r--r-- | models/clip/util.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/models/clip/util.py b/models/clip/util.py
new file mode 100644
index 0000000..8de8c19
--- /dev/null
+++ b/models/clip/util.py
@@ -0,0 +1,34 @@
1 | from typing import Optional | ||
2 | |||
3 | import torch | ||
4 | |||
5 | from transformers import CLIPTokenizer, CLIPTextModel | ||
6 | |||
7 | |||
def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]):
    """Pad a batch of token-id sequences into a uniform tensor batch.

    Delegates to ``tokenizer.pad`` so every sequence is padded up to a
    multiple of the tokenizer's ``model_max_length`` and the result is
    returned as PyTorch tensors (a ``BatchEncoding``).
    """
    batch = {"input_ids": input_ids}
    return tokenizer.pad(
        batch,
        padding=True,
        pad_to_multiple_of=tokenizer.model_max_length,
        return_tensors="pt",
    )
15 | |||
16 | |||
def get_extended_embeddings(
    text_encoder: CLIPTextModel,
    input_ids: torch.LongTensor,
    position_ids: Optional[torch.LongTensor] = None,
    attention_mask=None
):
    """Encode prompts that may exceed the encoder's context window.

    Each prompt row in ``input_ids`` is split into context-sized chunks of
    ``max_position_embeddings`` tokens, every chunk is run through the text
    encoder as an independent sample, and the chunk embeddings are stitched
    back together so the result has one row per original prompt.

    NOTE(review): assumes every prompt's length is an exact multiple of the
    context size (e.g. produced by ``unify_input_ids``) — the ``view`` calls
    below raise otherwise.
    """
    chunk_len = text_encoder.config.max_position_embeddings
    num_prompts = input_ids.shape[0]
    device = text_encoder.device

    # Fold the (prompt, chunk) dimensions together: one encoder "sample"
    # per context-sized window.
    input_ids = input_ids.view((-1, chunk_len)).to(device)
    if position_ids is not None:
        position_ids = position_ids.view((-1, chunk_len)).to(device)
    if attention_mask is not None:
        attention_mask = attention_mask.view((-1, chunk_len)).to(device)

    # [0] selects the last hidden state from the encoder output tuple.
    embeddings = text_encoder(
        input_ids, position_ids=position_ids, attention_mask=attention_mask
    )[0]
    # Unfold back to (num_prompts, total_tokens, hidden_dim).
    return embeddings.view((num_prompts, -1, embeddings.shape[2]))