From 6b58e9de249e872bd2d83e5916e6c633f52cfbb8 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 31 Dec 2022 12:58:54 +0100 Subject: Added multi-vector embeddings --- models/clip/prompt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'models/clip/prompt.py') diff --git a/models/clip/prompt.py b/models/clip/prompt.py index da33ecf..9da3955 100644 --- a/models/clip/prompt.py +++ b/models/clip/prompt.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import Union import torch @@ -10,13 +10,13 @@ class PromptProcessor(): self.tokenizer = tokenizer self.text_encoder = text_encoder - def get_input_ids(self, prompt: Union[str, List[str]]): + def get_input_ids(self, prompt: Union[str, list[str]]): return self.tokenizer( prompt, padding="do_not_pad", ).input_ids - def unify_input_ids(self, input_ids: List[int]): + def unify_input_ids(self, input_ids: list[int]): return self.tokenizer.pad( {"input_ids": input_ids}, padding=True, -- cgit v1.2.3-54-g00ecf