summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/clip/prompt.py10
1 files changed, 6 insertions, 4 deletions
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index 9da3955..a7380be 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -1,4 +1,4 @@
1from typing import Union 1from typing import Union, Optional
2 2
3import torch 3import torch
4 4
@@ -16,7 +16,7 @@ class PromptProcessor():
16 padding="do_not_pad", 16 padding="do_not_pad",
17 ).input_ids 17 ).input_ids
18 18
19 def unify_input_ids(self, input_ids: list[int]): 19 def unify_input_ids(self, input_ids: list[list[int]]):
20 return self.tokenizer.pad( 20 return self.tokenizer.pad(
21 {"input_ids": input_ids}, 21 {"input_ids": input_ids},
22 padding=True, 22 padding=True,
@@ -24,13 +24,15 @@ class PromptProcessor():
24 return_tensors="pt" 24 return_tensors="pt"
25 ) 25 )
26 26
27 def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None): 27 def get_embeddings(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None, attention_mask=None):
28 prompts = input_ids.shape[0] 28 prompts = input_ids.shape[0]
29 29
30 input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) 30 input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
31 if position_ids is not None:
32 position_ids = position_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
31 if attention_mask is not None: 33 if attention_mask is not None:
32 attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) 34 attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
33 35
34 text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0] 36 text_embeddings = self.text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
35 text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) 37 text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
36 return text_embeddings 38 return text_embeddings