From 3396ca881ed3f3521617cd9024eea56975191d32 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 5 Jan 2023 13:26:32 +0100
Subject: Update

---
 models/clip/prompt.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

(limited to 'models')

diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index 9da3955..a7380be 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Union, Optional
 
 import torch
 
@@ -16,7 +16,7 @@ class PromptProcessor():
             padding="do_not_pad",
         ).input_ids
 
-    def unify_input_ids(self, input_ids: list[int]):
+    def unify_input_ids(self, input_ids: list[list[int]]):
         return self.tokenizer.pad(
             {"input_ids": input_ids},
             padding=True,
@@ -24,13 +24,15 @@ class PromptProcessor():
             return_tensors="pt"
         )
 
-    def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None):
+    def get_embeddings(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None, attention_mask=None):
         prompts = input_ids.shape[0]
 
         input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
+        if position_ids is not None:
+            position_ids = position_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
         if attention_mask is not None:
             attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
 
-        text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0]
+        text_embeddings = self.text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
         text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
 
         return text_embeddings
-- 
cgit v1.2.3-54-g00ecf