From b42e7fbc29fd8045a2b932eb8ae76587f51f7513 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 31 Dec 2022 17:12:12 +0100 Subject: Bugfixes for multi-vector token handling --- models/clip/tokenizer.py | 39 ++++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 13 deletions(-) (limited to 'models/clip/tokenizer.py') diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 7e08287..63566e0 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py @@ -44,20 +44,33 @@ class MultiCLIPTokenizer(CLIPTokenizer): return MultiCLIPTokenizerItem(new_tokens, meta_id, ids) - def encode(self, *args, vector_shuffle=True, **kwargs): - ids = super().encode(*args, **kwargs) - new_ids = [] + def expand_id(self, id: int, vector_shuffle=True): + if id in self.token_map: + tokens = self.token_map[id] - for id in ids: - if id in self.token_map: - tokens = self.token_map[id] + if vector_shuffle: + tokens = copy.copy(tokens) + np.random.shuffle(tokens) - if vector_shuffle: - tokens = copy.copy(tokens) - np.random.shuffle(tokens) + return tokens + else: + return [id] - new_ids = new_ids + self.token_map[id] - else: - new_ids.append(id) + def expand_ids(self, ids: list[int], vector_shuffle=True): + return [ + new_id + for id in ids + for new_id in self.expand_id(id, vector_shuffle) + ] - return new_ids + def _call_one(self, text, *args, vector_shuffle=True, **kwargs): + result = super()._call_one(text, *args, **kwargs) + + is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list) + + if is_batched: + result.input_ids = [self.expand_ids(batch, vector_shuffle) for batch in result.input_ids] + else: + result.input_ids = self.expand_ids(result.input_ids, vector_shuffle) + + return result -- cgit v1.2.3-54-g00ecf