import copy
from typing import NamedTuple, Union

import numpy as np
from transformers import CLIPTokenizer


class MultiCLIPTokenizerItem(NamedTuple):
    """Bookkeeping record for one multi-vector placeholder token."""

    # The placeholder string exactly as supplied by the caller.
    token: str
    # Vocab id of the meta/placeholder token itself.
    placeholder_id: int
    # Vocab ids of the per-vector sub-tokens the placeholder expands to.
    multi_ids: list[int]


class MultiCLIPTokenizer(CLIPTokenizer):
    """CLIPTokenizer that can expand a single placeholder token into
    several consecutive vector tokens at encode time (textual-inversion
    style multi-vector embeddings)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps a placeholder (meta) token id -> list of sub-token ids
        # that replace it during encode().
        self.token_map: dict[int, list[int]] = {}

    def add_multi_tokens(
        self,
        new_tokens: Union[str, list[str]],
        num_vectors: Union[int, list[int]] = 1,
    ) -> Union["MultiCLIPTokenizerItem", list["MultiCLIPTokenizerItem"]]:
        """Register one or more placeholder tokens.

        For a single token, ``num_vectors`` sub-tokens named
        ``"{token}_{i}"`` are added to the vocabulary (or just the token
        itself when ``num_vectors == 1``) and recorded in ``token_map``.

        Args:
            new_tokens: A placeholder string, or a list of them.
            num_vectors: Number of embedding vectors per token; either a
                single int (broadcast over a list of tokens) or a list
                matching ``new_tokens``.

        Returns:
            A ``MultiCLIPTokenizerItem`` for a single token, or a list of
            them when ``new_tokens`` is a list.

        Raises:
            ValueError: On mismatched list lengths, or when a list of
                ``num_vectors`` is given for a single token.
        """
        if isinstance(new_tokens, list):
            if isinstance(num_vectors, int):
                # Broadcast a single vector count over every token.
                num_vectors = [num_vectors] * len(new_tokens)

            if len(num_vectors) != len(new_tokens):
                raise ValueError("Expected new_tokens and num_vectors to have the same len")

            return [
                self.add_multi_tokens(new_token, vecs)
                for new_token, vecs in zip(new_tokens, num_vectors)
            ]

        if isinstance(num_vectors, list):
            raise ValueError("Expected num_vectors to be int for single token")

        super().add_tokens(new_tokens)

        if num_vectors == 1:
            # Single-vector case: the placeholder is its own (only) sub-token.
            multi_token = [new_tokens]
        else:
            multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)]
            super().add_tokens(multi_token)

        meta_id = super().convert_tokens_to_ids(new_tokens)
        multi_ids = super().convert_tokens_to_ids(multi_token)

        self.token_map[meta_id] = multi_ids

        return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids)

    def encode(self, *args, vector_shuffle: bool = True, **kwargs):
        """Encode text, expanding each registered placeholder id into its
        sub-token ids.

        Args:
            vector_shuffle: When True, the sub-token ids of each
                placeholder are emitted in a random order (shuffles a
                copy; ``token_map`` itself is never mutated).

        Returns:
            The expanded list of token ids.
        """
        ids = super().encode(*args, **kwargs)
        new_ids: list[int] = []

        for token_id in ids:
            if token_id in self.token_map:
                tokens = self.token_map[token_id]

                if vector_shuffle:
                    # Shuffle a copy so the stored mapping stays intact.
                    tokens = copy.copy(tokens)
                    np.random.shuffle(tokens)

                # BUGFIX: previously extended with the *unshuffled*
                # self.token_map[token_id], silently discarding the shuffle.
                new_ids.extend(tokens)
            else:
                new_ids.append(token_id)

        return new_ids