# models/clip/tokenizer.py

import copy
from typing import NamedTuple, Union, Literal

import numpy as np

from transformers import CLIPTokenizer


def shuffle_all(tokens: list[int]):
    """Shuffle every token. The input list is copied and left untouched."""
    if len(tokens) >= 2:
        tokens = copy.copy(tokens)
        np.random.shuffle(tokens)
    return tokens


def shuffle_leading(tokens: list[int]):
    """Shuffle the leading tokens, keeping the last token in place."""
    if len(tokens) >= 3:
        subtokens = tokens[:-1]  # slicing copies, so the input is untouched
        np.random.shuffle(subtokens)
        tokens = subtokens + tokens[-1:]
    return tokens


def shuffle_trailing(tokens: list[int]):
    """Shuffle the trailing tokens, keeping the first token in place."""
    if len(tokens) >= 3:
        subtokens = tokens[1:]
        np.random.shuffle(subtokens)
        tokens = tokens[:1] + subtokens
    return tokens


def shuffle_between(tokens: list[int]):
    """Shuffle the inner tokens, keeping the first and last tokens in place."""
    if len(tokens) >= 4:
        subtokens = tokens[1:-1]
        np.random.shuffle(subtokens)
        tokens = tokens[:1] + subtokens + tokens[-1:]
    return tokens


def shuffle_none(tokens: list[int]):
    """Identity: leave the tokens unchanged."""
    return tokens


def shuffle_auto(tokens: list[int]):
    """Pick the least destructive strategy the vector length allows: keep both
    ends fixed (4+ tokens), keep the first token fixed (3 tokens), or shuffle
    everything otherwise."""
    if len(tokens) >= 4:
        return shuffle_between(tokens)
    if len(tokens) >= 3:
        return shuffle_trailing(tokens)
    return shuffle_all(tokens)
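
# Illustrative summary (not part of the module API): for a 5-vector token
# [a, b, c, d, e], one possible outcome per strategy is
#
#   shuffle_all       -> [c, a, e, b, d]   (nothing fixed)
#   shuffle_leading   -> [c, a, d, b, e]   (last fixed)
#   shuffle_trailing  -> [a, d, b, e, c]   (first fixed)
#   shuffle_between   -> [a, d, c, b, e]   (first and last fixed)
#   shuffle_none      -> [a, b, c, d, e]   (identity)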


class MultiCLIPTokenizerItem(NamedTuple):
    """A placeholder token together with the ids of all of its sub-tokens."""
    token: str
    ids: list[int]


class MultiCLIPTokenizer(CLIPTokenizer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # maps the id of a multi-vector token to the full list of its sub-token ids
        self.token_map: dict[int, list[int]] = {}
        self.vector_shuffle = shuffle_none

    def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "auto", "off"]]):
        if algorithm == "leading":
            self.vector_shuffle = shuffle_leading
        elif algorithm == "trailing":
            self.vector_shuffle = shuffle_trailing
        elif algorithm == "between":
            self.vector_shuffle = shuffle_between
        elif algorithm == "auto":
            self.vector_shuffle = shuffle_auto
        elif algorithm is True or algorithm == "all":
            self.vector_shuffle = shuffle_all
        else:
            self.vector_shuffle = shuffle_none

    def add_multi_tokens(
        self,
        new_tokens: Union[str, list[str]],
        num_vectors: Union[int, list[int]] = 1
    ) -> Union[MultiCLIPTokenizerItem, list[MultiCLIPTokenizerItem]]:
        """Add one or more placeholder tokens, each backed by `num_vectors` sub-tokens."""
        if isinstance(new_tokens, list):
            if isinstance(num_vectors, int):
                num_vectors = [num_vectors] * len(new_tokens)

            if len(num_vectors) != len(new_tokens):
                raise ValueError("Expected new_tokens and num_vectors to have the same length")

            return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)]

        if isinstance(num_vectors, list):
            raise ValueError("Expected num_vectors to be an int for a single token")

        if num_vectors < 1:
            raise ValueError("Expected num_vectors to be >= 1")

        # the token itself plus numbered variants: token, token_1, token_2, ...
        tokens = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)]

        super().add_tokens(tokens)
        ids = super().convert_tokens_to_ids(tokens)

        self.token_map[ids[0]] = ids

        return MultiCLIPTokenizerItem(new_tokens, ids)

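    # Sketch of the resulting state (ids are hypothetical): after
    # add_multi_tokens("<cat>", num_vectors=3) the vocabulary gains
    # "<cat>", "<cat>_1" and "<cat>_2", and token_map maps the id of
    # "<cat>" to all three ids, e.g. {49408: [49408, 49409, 49410]}.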
    def expand_id(self, id: int):
        """Expand a single id into its (possibly shuffled) sub-token ids."""
        return self.vector_shuffle(self.token_map[id]) if id in self.token_map else [id]

    def expand_ids(self, ids: list[int]):
        return [
            new_id
            for id in ids
            for new_id in self.expand_id(id)
        ]
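    # Illustrative (hypothetical ids): with token_map == {5: [5, 6, 7]} and
    # shuffling disabled, expand_ids([1, 5, 2]) returns [1, 5, 6, 7, 2].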

    def expand_batched_ids(self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]):
        # handle both a single sequence and a batch of sequences; the length
        # check guards against indexing into an empty batch
        if isinstance(input_ids, (list, tuple)) and len(input_ids) != 0 and isinstance(input_ids[0], list):
            return [self.expand_ids(batch) for batch in input_ids]
        else:
            return self.expand_ids(input_ids)

    def _call_one(self, *args, **kwargs):
        result = super()._call_one(*args, **kwargs)
        # assign via __setitem__ so the BatchEncoding's underlying dict is
        # updated too, rather than shadowed by an instance attribute
        result["input_ids"] = self.expand_batched_ids(result["input_ids"])
        return result

    def encode(self, *args, **kwargs):
        result = super().encode(*args, **kwargs)
        result = self.expand_batched_ids(result)
        return result
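

# A minimal usage sketch (illustrative; the checkpoint name is an assumption,
# any CLIP tokenizer checkpoint works):
#
#     tokenizer = MultiCLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
#     tokenizer.set_use_vector_shuffle("auto")
#     item = tokenizer.add_multi_tokens("<cat>", num_vectors=3)
#     input_ids = tokenizer("a photo of <cat>").input_ids  # "<cat>" expands to 3 ids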