From 358874cd2c49cb55676af86d2950b86d9ccb023a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 10 Dec 2022 13:12:37 +0100 Subject: Support attention_mask of text encoder --- models/clip/prompt.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) (limited to 'models/clip') diff --git a/models/clip/prompt.py b/models/clip/prompt.py index 6b6b7e9..9b427a0 100644 --- a/models/clip/prompt.py +++ b/models/clip/prompt.py @@ -22,11 +22,15 @@ class PromptProcessor(): padding=True, pad_to_multiple_of=self.tokenizer.model_max_length, return_tensors="pt" - ).input_ids + ) - def get_embeddings(self, input_ids: torch.IntTensor): + def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None): prompts = input_ids.shape[0] + input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) - text_embeddings = self.text_encoder(input_ids)[0] + if attention_mask is not None: + attention_mask = attention_mask.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) + + text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0] text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2])) return text_embeddings -- cgit v1.2.3-70-g09d2