import copy
from typing import NamedTuple, Union

import numpy as np

from transformers import CLIPTokenizer


class MultiCLIPTokenizerItem(NamedTuple):
    """Bookkeeping record for one registered multi-vector token.

    token          -- the placeholder token string as given by the caller
    placeholder_id -- vocabulary id of the placeholder token itself
    multi_ids      -- ids of the sub-tokens the placeholder expands to
    """

    token: str
    placeholder_id: int
    multi_ids: list[int]


class MultiCLIPTokenizer(CLIPTokenizer):
    """CLIPTokenizer in which one placeholder token may expand to several ids.

    A token registered through ``add_multi_tokens`` with ``num_vectors > 1`` is
    stored once as a placeholder plus ``num_vectors`` derived sub-tokens
    (``"<tok>_0"``, ``"<tok>_1"``, ...). ``encode`` then replaces the
    placeholder's id with the sub-token ids, optionally in shuffled order.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps placeholder token id -> list of ids it expands to at encode time.
        self.token_map: dict[int, list[int]] = {}

    def add_multi_tokens(
        self,
        new_tokens: Union[str, list[str]],
        num_vectors: Union[int, list[int]] = 1,
    ) -> Union[MultiCLIPTokenizerItem, list[MultiCLIPTokenizerItem]]:
        """Register one or more multi-vector tokens.

        Args:
            new_tokens: a single token string, or a list of token strings.
            num_vectors: embedding vectors per token; an int is broadcast
                over a list of tokens.

        Returns:
            A MultiCLIPTokenizerItem for a single token, or a list of them
            when ``new_tokens`` is a list.

        Raises:
            ValueError: if the lists' lengths differ, if a list of counts is
                paired with a single token, or if a count is < 1 (which would
                otherwise register an empty expansion that silently drops the
                token during ``encode``).
        """
        if isinstance(new_tokens, list):
            if isinstance(num_vectors, int):
                num_vectors = [num_vectors] * len(new_tokens)

            if len(num_vectors) != len(new_tokens):
                raise ValueError("Expected new_tokens and num_vectors to have the same len")

            return [
                self.add_multi_tokens(new_token, vecs)
                for new_token, vecs in zip(new_tokens, num_vectors)
            ]

        if isinstance(num_vectors, list):
            raise ValueError("Expected num_vectors to be int for single token")

        if num_vectors < 1:
            raise ValueError("Expected num_vectors to be >= 1")

        super().add_tokens(new_tokens)

        if num_vectors == 1:
            # Single vector: the placeholder maps to itself; it is already in
            # the vocabulary, so no extra tokens are needed.
            multi_token = [new_tokens]
        else:
            multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)]
            super().add_tokens(multi_token)

        meta_id = super().convert_tokens_to_ids(new_tokens)
        multi_ids = super().convert_tokens_to_ids(multi_token)

        self.token_map[meta_id] = multi_ids

        return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids)

    def encode(self, *args, vector_shuffle=True, **kwargs):
        """Encode text, expanding placeholder ids into their multi-vector ids.

        Args:
            vector_shuffle: when True, randomize the order of each placeholder's
                sub-ids (acts as data augmentation during training).

        Returns:
            The list of token ids with every registered placeholder expanded.
        """
        new_ids = []

        for token_id in super().encode(*args, **kwargs):
            if token_id in self.token_map:
                expansion = self.token_map[token_id]

                if vector_shuffle:
                    # Shuffle a copy so the stored mapping keeps its order.
                    expansion = copy.copy(expansion)
                    np.random.shuffle(expansion)

                # BUG FIX: the original appended self.token_map[id] here,
                # discarding the freshly shuffled copy, so vector_shuffle
                # had no effect. Extend with the (possibly shuffled) copy.
                new_ids.extend(expansion)
            else:
                new_ids.append(token_id)

        return new_ids