summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--models/clip/prompt.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index 259ac44..6b6b7e9 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -27,5 +27,6 @@ class PromptProcessor():
27 def get_embeddings(self, input_ids: torch.IntTensor): 27 def get_embeddings(self, input_ids: torch.IntTensor):
28 prompts = input_ids.shape[0] 28 prompts = input_ids.shape[0]
29 input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) 29 input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
30 text_embeddings = self.text_encoder(input_ids)[0].reshape((prompts, -1, 768)) 30 text_embeddings = self.text_encoder(input_ids)[0]
31 text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2]))
31 return text_embeddings 32 return text_embeddings