diff options
-rw-r--r-- | models/clip/prompt.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/models/clip/prompt.py b/models/clip/prompt.py index 259ac44..6b6b7e9 100644 --- a/models/clip/prompt.py +++ b/models/clip/prompt.py | |||
@@ -27,5 +27,6 @@ class PromptProcessor(): | |||
27 | def get_embeddings(self, input_ids: torch.IntTensor): | 27 | def get_embeddings(self, input_ids: torch.IntTensor): |
28 | prompts = input_ids.shape[0] | 28 | prompts = input_ids.shape[0] |
29 | input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) | 29 | input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) |
30 | text_embeddings = self.text_encoder(input_ids)[0].reshape((prompts, -1, 768)) | 30 | text_embeddings = self.text_encoder(input_ids)[0] |
31 | text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2])) | ||
31 | return text_embeddings | 32 | return text_embeddings |