From dc463a6b8ef120b7a0643569b66f9109ed38c652 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 31 Dec 2022 14:07:44 +0100
Subject: Simplified multi-vector embedding code

---
 models/clip/tokenizer.py | 23 +++++++++++------------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 78871db..7e08287 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -8,8 +8,8 @@ from transformers import CLIPTokenizer
 
 class MultiCLIPTokenizerItem(NamedTuple):
     token: str
-    placeholder_id: int
-    multi_ids: list[int]
+    meta_id: int
+    ids: list[int]
 
 
 class MultiCLIPTokenizer(CLIPTokenizer):
@@ -30,20 +30,19 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         if isinstance(num_vectors, list):
             raise ValueError("Expected num_vectors to be int for single token")
 
-        super().add_tokens(new_tokens)
+        if num_vectors < 1:
+            raise ValueError("Expected num_vectors to be >= 1")
 
-        if num_vectors == 1:
-            multi_token = [new_tokens]
-        else:
-            multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)]
-            super().add_tokens(multi_token)
+        multi_token = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)]
 
-        meta_id = super().convert_tokens_to_ids(new_tokens)
-        multi_ids = super().convert_tokens_to_ids(multi_token)
+        super().add_tokens(multi_token)
 
-        self.token_map[meta_id] = multi_ids
+        ids = super().convert_tokens_to_ids(multi_token)
+        meta_id = ids[0]
 
-        return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids)
+        self.token_map[meta_id] = ids
+
+        return MultiCLIPTokenizerItem(new_tokens, meta_id, ids)
 
     def encode(self, *args, vector_shuffle=True, **kwargs):
         ids = super().encode(*args, **kwargs)
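
For reference, below is a minimal self-contained sketch of the expansion logic this
commit settles on. It imitates the new single-token code path only; the enclosing
method's def line falls outside the hunk, so the name add_multi_tokens is an
assumption, and FakeTokenizer is a hypothetical stand-in that mimics just enough of
CLIPTokenizer (add_tokens, convert_tokens_to_ids) for the example to run without
transformers installed.

    # Minimal sketch of the simplified multi-vector expansion. Assumptions:
    # the enclosing method is named add_multi_tokens (its signature is not
    # shown in the hunk), and FakeTokenizer is a hypothetical stand-in for
    # CLIPTokenizer so the example runs standalone.
    from typing import NamedTuple


    class MultiCLIPTokenizerItem(NamedTuple):
        token: str
        meta_id: int
        ids: list[int]


    class FakeTokenizer:
        """Hypothetical stand-in: assigns incrementing ids to new tokens."""

        def __init__(self):
            self.vocab: dict[str, int] = {}
            self.token_map: dict[int, list[int]] = {}

        def add_tokens(self, tokens: list[str]) -> int:
            for t in tokens:
                self.vocab.setdefault(t, len(self.vocab))
            return len(tokens)

        def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]:
            return [self.vocab[t] for t in tokens]

        def add_multi_tokens(self, new_tokens: str, num_vectors: int = 1) -> MultiCLIPTokenizerItem:
            if num_vectors < 1:
                raise ValueError("Expected num_vectors to be >= 1")

            # One list covers both cases: num_vectors == 1 yields just the
            # base token; larger values append "_1" .. "_{n-1}" copies.
            multi_token = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)]

            self.add_tokens(multi_token)

            ids = self.convert_tokens_to_ids(multi_token)
            meta_id = ids[0]  # the base token's id doubles as the meta id

            self.token_map[meta_id] = ids

            return MultiCLIPTokenizerItem(new_tokens, meta_id, ids)


    tok = FakeTokenizer()
    item = tok.add_multi_tokens("<concept>", num_vectors=3)
    assert item.meta_id == item.ids[0]
    assert tok.token_map[item.meta_id] == item.ids  # ids of <concept>, <concept>_1, <concept>_2

The simplification is visible in meta_id = ids[0]: the base token is no longer
registered separately before the suffixed copies, so one add_tokens call and one
convert_tokens_to_ids call replace the old two-step placeholder/multi-id
bookkeeping, and token_map keys the full id list by the base token's own id.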