1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
|
import copy
from typing import Union, Literal
import numpy as np
from transformers import CLIPTokenizer
def dropout(tokens: list[int], dropout: float):
    """Randomly discard tokens, deciding independently for each one.

    A token survives when a fresh ``np.random.random()`` draw is strictly
    greater than ``dropout``; with a rate in [0, 1] each token is therefore
    dropped with probability ``dropout``. A rate of exactly 0 short-circuits:
    the input object itself is returned and no random numbers are consumed.
    """
    if dropout == 0:
        return tokens
    return [token for token in tokens if np.random.random() > dropout]
def shuffle_all(tokens: list[int]):
    """Return a randomly permuted copy of ``tokens``.

    Inputs with fewer than two elements cannot be reordered and are returned
    as-is (same object). Otherwise a shallow copy is shuffled, so the caller's
    list is never modified.
    """
    if len(tokens) < 2:
        return tokens
    shuffled = copy.copy(tokens)
    np.random.shuffle(shuffled)
    return shuffled
def shuffle_leading(tokens: list[int]):
    """Shuffle every element except the last, which stays in place.

    Lists shorter than three elements are returned unchanged (same object);
    otherwise a new list is built and the input is left untouched.
    """
    if len(tokens) < 3:
        return tokens
    head, tail = tokens[:-1], tokens[-1:]
    np.random.shuffle(head)
    return head + tail
def shuffle_trailing(tokens: list[int]):
    """Shuffle every element except the first, which stays in place.

    Lists shorter than three elements are returned unchanged (same object);
    otherwise a new list is built and the input is left untouched.
    """
    if len(tokens) < 3:
        return tokens
    head, tail = tokens[:1], tokens[1:]
    np.random.shuffle(tail)
    return head + tail
def shuffle_between(tokens: list[int]):
    """Shuffle the interior elements while pinning the first and last ones.

    Lists shorter than four elements (fewer than two interior elements) are
    returned unchanged (same object); otherwise a new list is built and the
    input is left untouched.
    """
    if len(tokens) < 4:
        return tokens
    first, middle, last = tokens[:1], tokens[1:-1], tokens[-1:]
    np.random.shuffle(middle)
    return first + middle + last
def shuffle_none(tokens: list[int]):
    """Identity strategy: hand ``tokens`` back untouched (same object, no copy)."""
    return tokens
def shuffle_auto(tokens: list[int]):
    """Pick a shuffling strategy based on how many elements there are.

    - fewer than 3 elements: ``shuffle_all``
    - 3 or 4 elements: ``shuffle_trailing`` (first element pinned)
    - 5 or more elements: ``shuffle_between`` (first and last pinned)
    """
    count = len(tokens)
    if count < 3:
        return shuffle_all(tokens)
    if count < 5:
        return shuffle_trailing(tokens)
    return shuffle_between(tokens)
class MultiCLIPTokenizer(CLIPTokenizer):
    """CLIP tokenizer in which one added token can stand for several token ids.

    ``add_multi_tokens`` registers a placeholder token together with
    ``num_vectors - 1`` suffixed companions (``"<tok>_1"``, ``"<tok>_2"``, ...)
    and records the resulting ids in ``token_map``. The overridden
    ``_call_one`` and ``encode`` then replace every occurrence of the
    placeholder's id with the ids of all its vectors. In training mode
    (``train()``), each expansion is additionally reordered by
    ``vector_shuffle`` and thinned by ``dropout``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # First id of each multi-vector token -> ids of all its vectors.
        self.token_map: dict[int, list[int]] = {}
        # When True, expand_id shuffles and drops out vector ids.
        self.is_training = False
        # Reordering applied to a token's vector ids in training mode.
        self.vector_shuffle = shuffle_auto
        # Per-vector drop probability in training mode (0 disables dropout).
        self.dropout = 0

    def train(self):
        """Switch to training mode (vector shuffling and dropout enabled)."""
        self.is_training = True

    def eval(self):
        """Switch to evaluation mode (expansions returned as registered)."""
        self.is_training = False

    def set_dropout(self, dropout: float):
        """Set the probability of dropping each vector id in training mode."""
        self.dropout = dropout

    def set_use_vector_shuffle(
        self,
        algorithm: Union[bool, Literal["all", "auto", "trailing", "leading", "between", "off"]],
    ):
        """Choose how a token's vector ids are reordered in training mode.

        Args:
            algorithm: ``"leading"``, ``"trailing"``, ``"between"`` and
                ``"auto"`` select the matching ``shuffle_*`` function;
                ``"all"`` or ``True`` selects ``shuffle_all``. Anything else
                (``False``, ``"off"``, unknown names) disables shuffling.
        """
        by_name = {
            "leading": shuffle_leading,
            "trailing": shuffle_trailing,
            "between": shuffle_between,
            "auto": shuffle_auto,
            "all": shuffle_all,
        }
        if isinstance(algorithm, str):
            self.vector_shuffle = by_name.get(algorithm, shuffle_none)
        elif algorithm == True:  # noqa: E712 -- equality (not `is`) also accepts e.g. 1
            self.vector_shuffle = shuffle_all
        else:
            self.vector_shuffle = shuffle_none

    def add_multi_tokens(
        self,
        new_tokens: Union[str, list[str]],
        num_vectors: Union[int, list[int]] = 1,
    ) -> Union[list[int], list[list[int]]]:
        """Register one or more multi-vector placeholder tokens.

        Args:
            new_tokens: A placeholder token string, or a list of them.
            num_vectors: Number of vectors (token ids) per placeholder. Pass
                either one int applied to every token or a list whose length
                matches ``new_tokens``.

        Returns:
            The ids of the token's vectors (placeholder id first), or one such
            list per token when ``new_tokens`` is a list.

        Raises:
            ValueError: If ``num_vectors`` is a list of the wrong length, is a
                list for a single token, or is smaller than 1.
        """
        if isinstance(new_tokens, list):
            if isinstance(num_vectors, int):
                num_vectors = [num_vectors] * len(new_tokens)
            if len(num_vectors) != len(new_tokens):
                raise ValueError("Expected new_tokens and num_vectors to have the same len")
            return [
                self.add_multi_tokens(new_token, vecs)
                for new_token, vecs in zip(new_tokens, num_vectors)
            ]

        if isinstance(num_vectors, list):
            raise ValueError("Expected num_vectors to be int for single token")
        if num_vectors < 1:
            raise ValueError("Expected num_vectors to be >= 1")

        # The first vector reuses the placeholder's own name; the remaining
        # vectors get numbered suffixes so each has a distinct vocab entry.
        tokens = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)]
        super().add_tokens(tokens)
        ids = super().convert_tokens_to_ids(tokens)
        # Keyed by the placeholder's id so expand_id can look the group up.
        self.token_map[ids[0]] = ids
        return ids

    def expand_id(self, id: int) -> list[int]:
        """Return the ids that ``id`` expands to (``[id]`` for ordinary tokens).

        For a registered multi-vector token the full id list is returned. In
        training mode it is first reordered by ``vector_shuffle`` and then
        thinned by ``dropout``, so it may come back shorter, or even empty.
        """
        ids = self.token_map.get(id)
        if ids is None:
            return [id]
        if self.is_training:
            ids = dropout(self.vector_shuffle(ids), self.dropout)
        # Copy so callers can never mutate the list stored in token_map.
        return list(ids)

    def expand_ids(self, ids: list[int]) -> list[int]:
        """Expand every id in one sequence and flatten the results."""
        return [new_id for id in ids for new_id in self.expand_id(id)]

    def expand_batched_ids(self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]):
        """Expand a single id sequence, or each sequence of a batch.

        A list/tuple whose first element is a list is treated as a batch; any
        other input (including an empty list) is treated as one sequence.
        NOTE(review): non-list containers such as tensors fall through to
        ``expand_ids`` and are iterated element by element -- confirm they are
        never passed here.
        """
        # Check emptiness before indexing: input_ids[0] on [] raises IndexError.
        if isinstance(input_ids, (list, tuple)) and input_ids and isinstance(input_ids[0], list):
            return [self.expand_ids(batch) for batch in input_ids]
        return self.expand_ids(input_ids)

    def _call_one(self, *args, **kwargs):
        """Tokenize via the parent class, then expand multi-vector token ids.

        NOTE(review): only ``input_ids`` is expanded. Any other per-token
        fields the parent returns keep the unexpanded length, and so does any
        padding or truncation it applied. Confirm callers do not rely on them.
        """
        result = super()._call_one(*args, **kwargs)
        # Write through the mapping interface. BatchEncoding defines no
        # __setattr__, so `result.input_ids = ...` would only add a shadowing
        # instance attribute, leaving result["input_ids"] (and **result) stale.
        result["input_ids"] = self.expand_batched_ids(result["input_ids"])
        return result

    def encode(self, *args, **kwargs):
        """Encode via the parent class, then expand multi-vector token ids."""
        result = super().encode(*args, **kwargs)
        return self.expand_batched_ids(result)
|