From 2847ab264830ab78491a31c168b3a24bfcf66334 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 27 Nov 2022 16:57:58 +0100 Subject: Make prompt processor compatible with any model --- models/clip/prompt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'models') diff --git a/models/clip/prompt.py b/models/clip/prompt.py index 259ac44..6b6b7e9 100644 --- a/models/clip/prompt.py +++ b/models/clip/prompt.py @@ -27,5 +27,6 @@ class PromptProcessor(): def get_embeddings(self, input_ids: torch.IntTensor): prompts = input_ids.shape[0] input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) - text_embeddings = self.text_encoder(input_ids)[0].reshape((prompts, -1, 768)) + text_embeddings = self.text_encoder(input_ids)[0] + text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2])) return text_embeddings -- cgit v1.2.3-54-g00ecf