summaryrefslogtreecommitdiffstats
path: root/models/clip/tokenizer.py
blob: a86664195e49e61807ebaaecc6f27cc148321675 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import copy
from typing import Union, Literal

import numpy as np

from transformers import CLIPTokenizer


def dropout(tokens: list[int], dropout: float):
    """Randomly discard entries of ``tokens``.

    Each entry survives only if a fresh ``np.random.random()`` draw is
    strictly greater than ``dropout``, so for values in [0, 1) ``dropout``
    is the per-entry drop probability. A ``dropout`` of exactly 0 returns
    ``tokens`` itself (same object) without drawing any random numbers.
    """
    if dropout == 0:
        return tokens
    return [tok for tok in tokens if np.random.random() > dropout]


def shuffle_all(tokens: list[int]):
    """Return a randomly permuted copy of ``tokens``.

    Inputs with fewer than two elements cannot be reordered and are returned
    unchanged (same object). The caller's list is never modified.
    """
    if len(tokens) < 2:
        return tokens
    permuted = copy.copy(tokens)
    np.random.shuffle(permuted)
    return permuted


def shuffle_leading(tokens: list[int]):
    """Shuffle every vector except the last one, which stays in place.

    At least three elements are needed for this to do anything (two movable
    ones plus the fixed tail); shorter inputs are returned as-is. The input
    list is not modified because slicing produces copies.
    """
    if len(tokens) < 3:
        return tokens
    body, tail = tokens[:-1], tokens[-1:]
    np.random.shuffle(body)
    return body + tail


def shuffle_trailing(tokens: list[int]):
    """Shuffle every vector except the first one, which stays in place.

    At least three elements are needed for this to do anything (the fixed
    head plus two movable ones); shorter inputs are returned as-is. The
    input list is not modified because slicing produces copies.
    """
    if len(tokens) < 3:
        return tokens
    head, body = tokens[:1], tokens[1:]
    np.random.shuffle(body)
    return head + body


def shuffle_between(tokens: list[int]):
    """Shuffle the inner vectors, keeping the first and last ones in place.

    At least four elements are needed for this to do anything (two fixed
    ends plus two movable ones); shorter inputs are returned as-is. The
    input list is not modified because slicing produces copies.
    """
    if len(tokens) < 4:
        return tokens
    head, body, tail = tokens[:1], tokens[1:-1], tokens[-1:]
    np.random.shuffle(body)
    return head + body + tail


def shuffle_none(tokens: list[int]):
    """No-op shuffle strategy: hand back ``tokens`` (the same object) untouched."""
    unchanged = tokens
    return unchanged


def shuffle_auto(tokens: list[int]):
    """Choose a shuffle strategy based on how many vectors there are.

    - fewer than 3 vectors: shuffle everything (``shuffle_all``)
    - 3 or 4 vectors: keep the first, shuffle the rest (``shuffle_trailing``)
    - 5 or more vectors: keep both ends, shuffle the middle (``shuffle_between``)
    """
    count = len(tokens)
    if count < 3:
        return shuffle_all(tokens)
    if count < 5:
        return shuffle_trailing(tokens)
    return shuffle_between(tokens)


# Accepted values for MultiCLIPTokenizer.set_use_vector_shuffle.
# True behaves like "all"; False, "off" (and any unrecognized value) disable
# shuffling. "auto" selects a strategy based on the number of vectors.
ShuffleAlgorithm = Union[
    bool, Literal["all", "trailing", "leading", "between", "auto", "off"]
]


class MultiCLIPTokenizer(CLIPTokenizer):
    """CLIPTokenizer that lets one placeholder token stand for several vectors.

    Tokens registered via :meth:`add_multi_tokens` are recorded in
    ``token_map`` as ``first_id -> [first_id, id_1, ..., id_{n-1}]``. After the
    parent tokenizer has produced ids, every occurrence of ``first_id`` is
    replaced by its full id list. In training mode that list is additionally
    reordered by the configured shuffle strategy and then passed through
    random dropout.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # first id of each multi-vector token -> all ids belonging to it
        self.token_map: dict[int, list[int]] = {}
        self.is_training = False
        self.vector_shuffle = shuffle_auto
        self.dropout = 0

    def train(self):
        """Enable shuffling and dropout when expanding multi-vector ids."""
        self.is_training = True

    def eval(self):
        """Disable shuffling and dropout; expansion becomes deterministic."""
        self.is_training = False

    def set_dropout(self, dropout: float):
        """Set the probability of dropping each vector while training."""
        self.dropout = dropout

    def set_use_vector_shuffle(self, algorithm: ShuffleAlgorithm):
        """Select how multi-vector id lists are shuffled while training.

        ``"leading"`` keeps the last vector fixed, ``"trailing"`` keeps the
        first, ``"between"`` keeps both ends, ``"auto"`` picks one of these
        based on the vector count, and ``True`` / ``"all"`` shuffle every
        vector. ``False``, ``"off"`` and any other value disable shuffling.
        """
        if algorithm == "leading":
            self.vector_shuffle = shuffle_leading
        elif algorithm == "trailing":
            self.vector_shuffle = shuffle_trailing
        elif algorithm == "between":
            self.vector_shuffle = shuffle_between
        elif algorithm == "auto":
            self.vector_shuffle = shuffle_auto
        # `== True` (rather than `is True`) also accepts values equal to True,
        # e.g. 1 or numpy.bool_(True); kept for backward compatibility.
        elif algorithm == True or algorithm == "all":  # noqa: E712
            self.vector_shuffle = shuffle_all
        else:
            self.vector_shuffle = shuffle_none

    def add_multi_tokens(
        self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1
    ) -> Union[list[int], list[list[int]]]:
        """Register token(s) that expand to ``num_vectors`` ids each.

        A token ``"<tok>"`` with ``num_vectors=3`` adds the vocabulary entries
        ``"<tok>"``, ``"<tok>_1"``, ``"<tok>_2"`` and maps the id of
        ``"<tok>"`` to all three ids.

        Args:
            new_tokens: a single token, or a list of tokens.
            num_vectors: vectors per token; for a list of tokens either one int
                applied to all of them or a list of the same length.

        Returns:
            The ids of the token (list of int), or one such list per token
            when ``new_tokens`` is a list.

        Raises:
            ValueError: on a length mismatch between ``new_tokens`` and
                ``num_vectors``, a list ``num_vectors`` for a single token, or
                ``num_vectors < 1``.
        """
        if isinstance(new_tokens, list):
            if isinstance(num_vectors, int):
                num_vectors = [num_vectors] * len(new_tokens)

            if len(num_vectors) != len(new_tokens):
                raise ValueError(
                    "Expected new_tokens and num_vectors to have the same len"
                )

            return [
                self.add_multi_tokens(new_token, vecs)
                for new_token, vecs in zip(new_tokens, num_vectors)
            ]

        if isinstance(num_vectors, list):
            raise ValueError("Expected num_vectors to be int for single token")

        if num_vectors < 1:
            raise ValueError("Expected num_vectors to be >= 1")

        tokens = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)]

        super().add_tokens(tokens)
        ids = super().convert_tokens_to_ids(tokens)

        self.token_map[ids[0]] = ids

        return ids

    def expand_id(self, id: int):
        """Return the list of ids that ``id`` expands to.

        Ids not registered through :meth:`add_multi_tokens` expand to ``[id]``.
        In training mode a multi-vector expansion is shuffled and then
        subjected to dropout. A new list is always returned.
        """
        if id not in self.token_map:
            return [id]

        ids = self.token_map[id]
        if self.is_training:
            ids = dropout(self.vector_shuffle(ids), self.dropout)
        # Without training (or with shuffle_none and zero dropout) `ids` is the
        # list stored in token_map; copy it so callers cannot mutate the map.
        return list(ids)

    def expand_ids(self, ids: list[int]):
        """Expand every id of a single sequence and flatten the result."""
        return [new_id for id in ids for new_id in self.expand_id(id)]

    def expand_batched_ids(
        self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]
    ):
        """Expand a single id sequence or a batch (list/tuple of lists) of them."""
        # Check emptiness before probing input_ids[0]: an empty sequence used
        # to raise IndexError here; it now expands to an empty list.
        is_batch = (
            isinstance(input_ids, (list, tuple))
            and len(input_ids) > 0
            and isinstance(input_ids[0], list)
        )
        if is_batch:
            return [self.expand_ids(batch) for batch in input_ids]
        return self.expand_ids(input_ids)

    def _call_one(self, *args, **kwargs):
        """Run the parent tokenizer, then expand multi-vector ids in the result."""
        result = super()._call_one(*args, **kwargs)
        # BatchEncoding stores its fields in an internal dict and defines
        # __getattr__ but not __setattr__: assigning `result.input_ids = ...`
        # only created a shadowing instance attribute, leaving
        # result["input_ids"] (and mapping-based consumers) unexpanded.
        # Item assignment updates the stored field, and attribute reads still
        # see it through __getattr__.
        result["input_ids"] = self.expand_batched_ids(result["input_ids"])
        # NOTE(review): other per-token fields (e.g. attention_mask, if
        # requested) are not expanded and may no longer match input_ids in
        # length -- confirm whether any caller relies on them.
        return result

    def encode(self, *args, **kwargs):
        """Like ``CLIPTokenizer.encode`` but with multi-vector ids expanded."""
        result = super().encode(*args, **kwargs)
        result = self.expand_batched_ids(result)
        return result