from typing import NamedTuple, Union

import numpy as np

from transformers import CLIPTokenizer


class MultiCLIPTokenizerItem(NamedTuple):
    token: str
    ids: list[int]


class MultiCLIPTokenizer(CLIPTokenizer):
    """CLIP tokenizer that maps a single placeholder token to multiple
    embedding vectors, expanding one token id into several at encode time."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps the id of a placeholder token to the full list of ids it
        # expands to (including itself as the first entry).
        self.token_map: dict[int, list[int]] = {}
        self.vector_shuffle = False

    def set_use_vector_shuffle(self, enable: bool):
        self.vector_shuffle = enable

    def add_multi_tokens(
        self,
        new_tokens: Union[str, list[str]],
        num_vectors: Union[int, list[int]] = 1,
    ) -> Union[MultiCLIPTokenizerItem, list[MultiCLIPTokenizerItem]]:
        if isinstance(new_tokens, list):
            if isinstance(num_vectors, int):
                num_vectors = [num_vectors] * len(new_tokens)

            if len(num_vectors) != len(new_tokens):
                raise ValueError("Expected new_tokens and num_vectors to have the same length")

            return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)]

        if isinstance(num_vectors, list):
            raise ValueError("Expected num_vectors to be an int for a single token")

        if num_vectors < 1:
            raise ValueError("Expected num_vectors to be >= 1")

        # Register "token", "token_1", ..., "token_{n-1}" as individual
        # tokens; the id of the first one keys the whole group.
        multi_token = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)]

        super().add_tokens(multi_token)

        ids = super().convert_tokens_to_ids(multi_token)
        self.token_map[ids[0]] = ids

        return MultiCLIPTokenizerItem(new_tokens, ids)

    def expand_id(self, token_id: int) -> list[int]:
        if token_id not in self.token_map:
            return [token_id]

        tokens = self.token_map[token_id]

        # With vector shuffle enabled, permute the interior ids while keeping
        # the first and last anchored (a no-op for groups of three or fewer).
        if self.vector_shuffle and len(tokens) > 2:
            subtokens = tokens[1:-1]
            np.random.shuffle(subtokens)
            tokens = tokens[:1] + subtokens + tokens[-1:]

        return tokens

    def expand_ids(self, ids: list[int]) -> list[int]:
        return [
            new_id
            for token_id in ids
            for new_id in self.expand_id(token_id)
        ]

    def _call_one(self, text, *args, **kwargs):
        result = super()._call_one(text, *args, **kwargs)

        is_batched = (
            isinstance(result.input_ids, (list, tuple))
            and len(result.input_ids) > 0
            and isinstance(result.input_ids[0], list)
        )

        if is_batched:
            result.input_ids = [self.expand_ids(batch) for batch in result.input_ids]
        else:
            result.input_ids = self.expand_ids(result.input_ids)

        return result
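

# --- Usage sketch (illustrative, not part of the module) --------------------
# A minimal example of how this tokenizer is meant to be used, assuming a
# standard CLIP checkpoint such as "openai/clip-vit-base-patch32" is
# available. The placeholder name "<concept>" is hypothetical; substitute
# whatever placeholder your training setup defines.
if __name__ == "__main__":
    tokenizer = MultiCLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    # Register a placeholder that expands to 3 embedding vectors:
    # "<concept>", "<concept>_1", "<concept>_2".
    item = tokenizer.add_multi_tokens("<concept>", num_vectors=3)
    print(item.token, item.ids)

    # Optionally shuffle the interior vector ids on every encode.
    tokenizer.set_use_vector_shuffle(True)

    # The single placeholder in the prompt is expanded to all of its ids.
    encoded = tokenizer("a photo of <concept>")
    print(encoded.input_ids)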