From ee9a2777c15d4ceea7ef40802b9a21881f6428a8 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 22 Dec 2022 21:15:24 +0100
Subject: Fixed Textual Inversion

---
 models/clip/prompt.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index 9b427a0..da33ecf 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -27,10 +27,10 @@ class PromptProcessor():
     def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None):
         prompts = input_ids.shape[0]

-        input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
+        input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
         if attention_mask is not None:
-            attention_mask = attention_mask.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
+            attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)

         text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0]
-        text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2]))
+        text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
         return text_embeddings
--
cgit v1.2.3-70-g09d2
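
Note: the shape handling this patch touches can be exercised in isolation with the sketch below. It is not the project's code: the chunk count, hidden size, and dummy encoder are illustrative assumptions, and only the view() calls and the tensor shapes mirror the patched get_embeddings(). The relevant PyTorch difference is that Tensor.view() never copies and raises an error if the layout is incompatible, whereas Tensor.reshape() may silently return a copy.

    # Minimal sketch of the shape flow in get_embeddings(), with a dummy
    # encoder standing in for the CLIP text encoder (illustrative only).
    import torch

    model_max_length = 77   # CLIP tokenizer limit
    hidden_size = 768       # hidden size, e.g. CLIP ViT-L/14 text encoder
    prompts, chunks = 2, 3  # 2 prompts, each split into 3 max-length chunks

    # (prompts, chunks * model_max_length), as produced by chunked tokenization
    input_ids = torch.randint(0, 49408, (prompts, chunks * model_max_length))

    # Flatten the chunk dimension without copying; view() fails loudly if the
    # tensor layout is incompatible instead of silently copying like reshape().
    flat_ids = input_ids.view((-1, model_max_length))
    assert flat_ids.shape == (prompts * chunks, model_max_length)

    # Dummy encoder: returns (batch, seq_len, hidden_size), like the CLIP
    # text encoder's last_hidden_state.
    def dummy_text_encoder(ids):
        return torch.randn(ids.shape[0], ids.shape[1], hidden_size)

    text_embeddings = dummy_text_encoder(flat_ids)

    # Merge the chunks back so each prompt gets one long embedding sequence.
    text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
    assert text_embeddings.shape == (prompts, chunks * model_max_length, hidden_size)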