From 7b149930bb53b93db74106ad20a30abf4b114f9b Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 13 Jan 2023 13:49:35 +0100 Subject: Removed PromptProcessor, modularized training loop --- models/clip/prompt.py | 38 -------------------------------------- 1 file changed, 38 deletions(-) delete mode 100644 models/clip/prompt.py (limited to 'models/clip/prompt.py') diff --git a/models/clip/prompt.py b/models/clip/prompt.py deleted file mode 100644 index a7380be..0000000 --- a/models/clip/prompt.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Union, Optional - -import torch - -from transformers import CLIPTokenizer, CLIPTextModel - - -class PromptProcessor(): - def __init__(self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel): - self.tokenizer = tokenizer - self.text_encoder = text_encoder - - def get_input_ids(self, prompt: Union[str, list[str]]): - return self.tokenizer( - prompt, - padding="do_not_pad", - ).input_ids - - def unify_input_ids(self, input_ids: list[list[int]]): - return self.tokenizer.pad( - {"input_ids": input_ids}, - padding=True, - pad_to_multiple_of=self.tokenizer.model_max_length, - return_tensors="pt" - ) - - def get_embeddings(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None, attention_mask=None): - prompts = input_ids.shape[0] - - input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) - if position_ids is not None: - position_ids = position_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) - if attention_mask is not None: - attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) - - text_embeddings = self.text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] - text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) - return text_embeddings -- cgit v1.2.3-54-g00ecf