# Provenance: commit 306f2bfb620e6882737658bd3694c79365d75e4b
# ("Improved prompt handling", Volpeon, Tue, 18 Oct 2022 15:23:40 +0200),
# which created models/clip/prompt.py.
from typing import TYPE_CHECKING, List, Optional, Union

import torch

if TYPE_CHECKING:
    # These names are referenced only in annotations, so the import is
    # restricted to type checkers; the annotations below are quoted to match.
    from transformers import CLIPTextModel, CLIPTokenizer


class PromptProcessor():
    """Pair a CLIP tokenizer with a CLIP text encoder for prompt handling.

    The helpers cover three steps: tokenizing prompts without padding,
    padding id sequences into one tensor, and encoding that tensor in
    ``model_max_length``-sized chunks.
    """

    def __init__(self, tokenizer: "CLIPTokenizer", text_encoder: "CLIPTextModel"):
        """Store the tokenizer and text encoder used by the other methods."""
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder

    def get_input_ids(self, prompt: Union[str, List[str]]):
        """Tokenize ``prompt`` without padding and return its ``input_ids``.

        Only ``padding="do_not_pad"`` is passed to the tokenizer; every other
        tokenizer option keeps its default.
        """
        encoded = self.tokenizer(
            prompt,
            padding="do_not_pad",
        )
        return encoded.input_ids

    def unify_input_ids(self, input_ids: List[int]):
        """Pad ``input_ids`` via ``tokenizer.pad`` and return them as a tensor.

        The tokenizer is asked for ``padding=True``, a length that is a
        multiple of ``tokenizer.model_max_length`` and PyTorch tensors
        (``return_tensors="pt"``).
        """
        # NOTE(review): the annotation says List[int], but get_embeddings reads
        # shape[0] as the prompt count, so callers presumably pass one id list
        # per prompt (List[List[int]]) -- confirm against the call sites.
        padded = self.tokenizer.pad(
            {"input_ids": input_ids},
            padding=True,
            pad_to_multiple_of=self.tokenizer.model_max_length,
            return_tensors="pt"
        )
        return padded.input_ids

    def get_embeddings(self, input_ids: torch.IntTensor):
        """Encode ``input_ids`` chunk-wise and return one embedding row per prompt.

        ``input_ids`` is split into rows of ``tokenizer.model_max_length``
        tokens, so its element count must be divisible by that length. The
        rows are moved to the encoder's device, encoded, and the first encoder
        output is folded back to ``(input_ids.shape[0], -1, hidden_size)``.
        For a ``(prompts, tokens)`` input this places the chunks of each
        prompt one after another along the token axis.

        The hidden size is read from the encoder output instead of being
        fixed at 768, so encoders of other widths work as well.
        """
        prompt_count = input_ids.shape[0]
        chunk_length = self.tokenizer.model_max_length
        chunks = input_ids.reshape((-1, chunk_length)).to(self.text_encoder.device)
        # Index 0 is the first encoder output (presumably the last hidden
        # state for a Hugging Face CLIPTextModel -- confirm).
        hidden_states = self.text_encoder(chunks)[0]
        text_embeddings = hidden_states.reshape(
            (prompt_count, -1, hidden_states.shape[-1])
        )
        return text_embeddings