From 646e6ee92f1887a19058953f3eaebfedd4c0df01 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 1 Jan 2023 20:36:50 +0100 Subject: Fix MultiCLIPTokenizer (forgot to override encode) --- models/clip/tokenizer.py | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) (limited to 'models/clip') diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index ed9774e..5e33f3e 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py @@ -74,7 +74,11 @@ class MultiCLIPTokenizer(CLIPTokenizer): else: self.vector_shuffle = shuffle_none - def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem: + def add_multi_tokens( + self, + new_tokens: Union[str, list[str]], + num_vectors: Union[int, list[int]] = 1 + ) -> Union[MultiCLIPTokenizerItem, list[MultiCLIPTokenizerItem]]: if isinstance(new_tokens, list): if isinstance(num_vectors, int): num_vectors = [num_vectors] * len(new_tokens) @@ -90,11 +94,10 @@ class MultiCLIPTokenizer(CLIPTokenizer): if num_vectors < 1: raise ValueError("Expected num_vectors to be >= 1") - multi_token = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)] + tokens = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)] - super().add_tokens(multi_token) - - ids = super().convert_tokens_to_ids(multi_token) + super().add_tokens(tokens) + ids = super().convert_tokens_to_ids(tokens) self.token_map[ids[0]] = ids @@ -110,14 +113,18 @@ class MultiCLIPTokenizer(CLIPTokenizer): for new_id in self.expand_id(id) ] - def _call_one(self, text, *args, **kwargs): - result = super()._call_one(text, *args, **kwargs) - - is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list) - - if is_batched: - result.input_ids = [self.expand_ids(batch) for batch in result.input_ids] + def expand_batched_ids(self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]): + if isinstance(input_ids, (list, tuple)) and isinstance(input_ids[0], list): + return [self.expand_ids(batch) for batch in input_ids] else: - result.input_ids = self.expand_ids(result.input_ids) + return self.expand_ids(input_ids) + + def _call_one(self, *args, **kwargs): + result = super()._call_one(*args, **kwargs) + result.input_ids = self.expand_batched_ids(result.input_ids) + return result + def encode(self, *args, **kwargs): + result = super().encode(*args, **kwargs) + result = self.expand_batched_ids(result) return result -- cgit v1.2.3-54-g00ecf