Diffstat (limited to 'models/clip/util.py')
-rw-r--r--  models/clip/util.py  17 ++++++++++++-----
1 file changed, 12 insertions(+), 5 deletions(-)
diff --git a/models/clip/util.py b/models/clip/util.py
index f94fbc7..7196bb6 100644
--- a/models/clip/util.py
+++ b/models/clip/util.py
@@ -5,27 +5,32 @@ import torch
 from transformers import CLIPTokenizer, CLIPTextModel
 
 
-def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]], max_length: Optional[int] = None):
+def unify_input_ids(
+    tokenizer: CLIPTokenizer,
+    input_ids: list[list[int]],
+    max_length: Optional[int] = None,
+):
     if max_length is None:
         return tokenizer.pad(
             {"input_ids": input_ids},
             padding=True,
             pad_to_multiple_of=tokenizer.model_max_length,
-            return_tensors="pt"
+            return_tensors="pt",
         )
     else:
         return tokenizer.pad(
             {"input_ids": input_ids},
             padding="max_length",
             max_length=max_length,
-            return_tensors="pt"
+            return_tensors="pt",
         )
 
+
 def get_extended_embeddings(
     text_encoder: CLIPTextModel,
     input_ids: torch.LongTensor,
     position_ids: Optional[torch.LongTensor] = None,
-    attention_mask=None
+    attention_mask=None,
 ):
     model_max_length = text_encoder.config.max_position_embeddings
     prompts = input_ids.shape[0]
@@ -36,6 +41,8 @@ def get_extended_embeddings(
     if attention_mask is not None:
         attention_mask = attention_mask.view((-1, model_max_length))
 
-    text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
+    text_embeddings = text_encoder(
+        input_ids, position_ids=position_ids, attention_mask=attention_mask
+    )[0]
     text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
     return text_embeddings
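
For context, a minimal usage sketch of the two helpers touched by this diff (not part of the commit): the checkpoint name and prompt strings below are illustrative assumptions, and the helpers are imported from models/clip/util as laid out in this repository.

# Minimal usage sketch; checkpoint and prompts are placeholders, not from the repo.
import torch
from transformers import CLIPTokenizer, CLIPTextModel

from models.clip.util import unify_input_ids, get_extended_embeddings

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

prompts = ["a photo of a cat", "a very long prompt " * 30]

# Tokenize without truncation, then pad each prompt to a multiple of the
# model's max length so it can be split into fixed-size chunks downstream.
input_ids = [tokenizer(p).input_ids for p in prompts]
batch = unify_input_ids(tokenizer, input_ids)

with torch.no_grad():
    embeddings = get_extended_embeddings(
        text_encoder,
        batch.input_ids,
        attention_mask=batch.attention_mask,
    )

# One row per prompt; the sequence dimension is a multiple of the
# encoder's max_position_embeddings (77 for CLIP).
print(embeddings.shape)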