import copy
from typing import Literal, Union

import numpy as np
from transformers import CLIPTokenizer


def dropout(tokens: list[int], dropout: float):
    # Independently drop each token with probability `dropout`.
    if dropout != 0:
        tokens = [token for token in tokens if np.random.random() > dropout]
    return tokens


def shuffle_all(tokens: list[int]):
    # Shuffle the whole list (on a copy, so the caller's list is untouched).
    if len(tokens) >= 2:
        tokens = copy.copy(tokens)
        np.random.shuffle(tokens)
    return tokens


def shuffle_leading(tokens: list[int]):
    # Shuffle every token except the last one.
    if len(tokens) >= 3:
        subtokens = tokens[:-1]
        np.random.shuffle(subtokens)
        tokens = subtokens + tokens[-1:]
    return tokens


def shuffle_trailing(tokens: list[int]):
    # Shuffle every token except the first one.
    if len(tokens) >= 3:
        subtokens = tokens[1:]
        np.random.shuffle(subtokens)
        tokens = tokens[:1] + subtokens
    return tokens


def shuffle_between(tokens: list[int]):
    # Shuffle the tokens between the first and last, keeping both anchors fixed.
    if len(tokens) >= 4:
        subtokens = tokens[1:-1]
        np.random.shuffle(subtokens)
        tokens = tokens[:1] + subtokens + tokens[-1:]
    return tokens


def shuffle_none(tokens: list[int]):
    return tokens


def shuffle_auto(tokens: list[int]):
    # Pick the least destructive shuffle the list length allows.
    if len(tokens) >= 5:
        return shuffle_between(tokens)
    if len(tokens) >= 3:
        return shuffle_trailing(tokens)
    return shuffle_all(tokens)


ShuffleAlgorithm = Union[bool, Literal["all", "trailing", "leading", "between", "auto", "off"]]


class MultiCLIPTokenizer(CLIPTokenizer):
    """CLIPTokenizer that maps a single placeholder token to multiple embedding vectors."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.token_map: dict[int, list[int]] = {}
        self.is_training = False
        self.vector_shuffle = shuffle_auto
        self.dropout = 0

    def train(self):
        self.is_training = True

    def eval(self):
        self.is_training = False

    def set_dropout(self, dropout: float):
        self.dropout = dropout

    def set_use_vector_shuffle(self, algorithm: ShuffleAlgorithm):
        if algorithm == "leading":
            self.vector_shuffle = shuffle_leading
        elif algorithm == "trailing":
            self.vector_shuffle = shuffle_trailing
        elif algorithm == "between":
            self.vector_shuffle = shuffle_between
        elif algorithm == "auto":
            self.vector_shuffle = shuffle_auto
        elif algorithm is True or algorithm == "all":
            self.vector_shuffle = shuffle_all
        else:  # False or "off"
            self.vector_shuffle = shuffle_none

    def add_multi_tokens(
        self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1
    ) -> Union[list[int], list[list[int]]]:
        if isinstance(new_tokens, list):
            if isinstance(num_vectors, int):
                num_vectors = [num_vectors] * len(new_tokens)
            if len(num_vectors) != len(new_tokens):
                raise ValueError("Expected new_tokens and num_vectors to have the same len")
            return [
                self.add_multi_tokens(new_token, vecs)
                for new_token, vecs in zip(new_tokens, num_vectors)
            ]

        if isinstance(num_vectors, list):
            raise ValueError("Expected num_vectors to be int for a single token")
        if num_vectors < 1:
            raise ValueError("Expected num_vectors to be >= 1")

        # Register "<token>", "<token>_1", ..., "<token>_{n-1}" and remember
        # the full id list under the primary token's id.
        tokens = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)]
        super().add_tokens(tokens)
        ids = super().convert_tokens_to_ids(tokens)
        self.token_map[ids[0]] = ids
        return ids

    def expand_id(self, id: int):
        if id in self.token_map:
            ids = self.token_map[id]
            if self.is_training:
                ids = dropout(self.vector_shuffle(ids), self.dropout)
            return ids
        return [id]

    def expand_ids(self, ids: list[int]):
        return [new_id for id in ids for new_id in self.expand_id(id)]

    def expand_batched_ids(self, input_ids: Union[list[int], list[list[int]], tuple[list[int], ...]]):
        if isinstance(input_ids, (list, tuple)) and len(input_ids) != 0 and isinstance(input_ids[0], list):
            return [self.expand_ids(batch) for batch in input_ids]
        return self.expand_ids(input_ids)

    def _call_one(self, *args, **kwargs):
        result = super()._call_one(*args, **kwargs)
        # Item assignment updates the BatchEncoding mapping itself; attribute
        # assignment would only shadow the "input_ids" key.
        result["input_ids"] = self.expand_batched_ids(result["input_ids"])
        return result

    def encode(self, *args, **kwargs):
        return self.expand_batched_ids(super().encode(*args, **kwargs))
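

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): shows how a multi-vector
# placeholder token is registered and how its id expands in train vs. eval
# mode. The "openai/clip-vit-large-patch14" checkpoint and the "<cat-toy>"
# placeholder are hypothetical example values.
if __name__ == "__main__":
    tokenizer = MultiCLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    # One placeholder backed by four embedding vectors.
    ids = tokenizer.add_multi_tokens("<cat-toy>", num_vectors=4)
    print(ids)  # four consecutive new token ids

    tokenizer.set_use_vector_shuffle("between")
    tokenizer.set_dropout(0.1)

    tokenizer.train()  # shuffle + dropout active: order and length may vary
    print(tokenizer.expand_id(ids[0]))

    tokenizer.eval()  # deterministic: full id list in original order
    print(tokenizer.expand_id(ids[0]))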