1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
|
import copy
from typing import NamedTuple, Union, Literal
import numpy as np
from transformers import CLIPTokenizer
def shuffle_all(tokens: list[int]):
    """Return a randomly permuted copy of ``tokens``.

    Inputs with fewer than two elements are returned unchanged (the very
    same object). The permutation is drawn from NumPy's global random
    state, and the caller's list is never modified.
    """
    if len(tokens) < 2:
        return tokens
    # Shuffle a shallow copy so the caller's list keeps its order.
    shuffled = copy.copy(tokens)
    np.random.shuffle(shuffled)
    return shuffled
def shuffle_leading(tokens: list[int]):
    """Shuffle every element except the last one, which stays in place.

    Inputs with fewer than three elements are returned unchanged. Otherwise
    a new list is returned and the input is left untouched. Randomness
    comes from NumPy's global random state.
    """
    if len(tokens) < 3:
        return tokens
    # Slicing yields a fresh list, so the in-place shuffle cannot affect
    # the caller's data.
    head = tokens[:-1]
    np.random.shuffle(head)
    return head + tokens[-1:]
def shuffle_trailing(tokens: list[int]):
    """Shuffle every element except the first one, which stays in place.

    Inputs with fewer than three elements are returned unchanged. Otherwise
    a new list is returned and the input is left untouched. Randomness
    comes from NumPy's global random state.
    """
    if len(tokens) < 3:
        return tokens
    # Slicing yields a fresh list, so the in-place shuffle cannot affect
    # the caller's data.
    tail = tokens[1:]
    np.random.shuffle(tail)
    return tokens[:1] + tail
def shuffle_between(tokens: list[int]):
    """Shuffle the interior elements, keeping the first and last fixed.

    Inputs with fewer than four elements are returned unchanged. Otherwise
    a new list is returned and the input is left untouched. Randomness
    comes from NumPy's global random state.
    """
    if len(tokens) < 4:
        return tokens
    first, last = tokens[:1], tokens[-1:]
    # The slice is a fresh list, so shuffling it leaves the input intact.
    middle = tokens[1:-1]
    np.random.shuffle(middle)
    return first + middle + last
def shuffle_none(tokens: list[int]):
    """Identity "shuffle": hand back ``tokens`` itself, order untouched."""
    return tokens
def shuffle_auto(tokens: list[int]):
    """Choose a shuffle strategy from the number of tokens.

    * four or more tokens: ``shuffle_between``
    * exactly three tokens: ``shuffle_trailing``
    * fewer than three tokens: ``shuffle_all``
    """
    count = len(tokens)
    if count < 3:
        return shuffle_all(tokens)
    if count == 3:
        return shuffle_trailing(tokens)
    return shuffle_between(tokens)
class MultiCLIPTokenizerItem(NamedTuple):
    """Pairs a token string with the list of ids associated with it."""

    # The token string.
    token: str
    # The ids that belong to ``token``.
    ids: list[int]
class MultiCLIPTokenizer(CLIPTokenizer):
    """CLIP tokenizer in which one registered token can expand to several ids.

    Tokens registered through ``add_multi_tokens`` are stored in
    ``token_map``, keyed by the id of their first sub-token. After the base
    tokenizer has produced ``input_ids``, every id found in ``token_map`` is
    replaced by its full id list, optionally reordered by the configured
    ``vector_shuffle`` function.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``CLIPTokenizer`` and start with no multi-tokens."""
        super().__init__(*args, **kwargs)
        # Maps the id of a multi-token's first sub-token to all of its ids.
        self.token_map: dict[int, list[int]] = {}
        self.vector_shuffle = shuffle_none

    def set_use_vector_shuffle(
        self,
        algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "auto", "off"]],
    ):
        """Select how the ids of a multi-token are reordered on expansion.

        Args:
            algorithm: ``"leading"``, ``"trailing"``, ``"between"`` or
                ``"auto"`` select the matching shuffle function; ``True`` or
                ``"all"`` shuffle every id; any other value (for example
                ``False`` or ``"off"``) disables shuffling.
        """
        if algorithm == "leading":
            self.vector_shuffle = shuffle_leading
        elif algorithm == "trailing":
            self.vector_shuffle = shuffle_trailing
        elif algorithm == "between":
            self.vector_shuffle = shuffle_between
        elif algorithm == "auto":
            self.vector_shuffle = shuffle_auto
        elif algorithm in (True, "all"):
            self.vector_shuffle = shuffle_all
        else:
            self.vector_shuffle = shuffle_none

    def add_multi_tokens(
        self,
        new_tokens: Union[str, list[str]],
        num_vectors: Union[int, list[int]] = 1,
    ) -> Union[MultiCLIPTokenizerItem, list[MultiCLIPTokenizerItem]]:
        """Register one or more tokens that each occupy ``num_vectors`` ids.

        A token ``tok`` with ``num_vectors == 3`` is added to the vocabulary
        as ``tok``, ``tok_1`` and ``tok_2``.

        Args:
            new_tokens: A single token string, or a list of them.
            num_vectors: Number of ids per token. With a list of tokens this
                may be one int applied to all of them, or a list of the same
                length as ``new_tokens``.

        Returns:
            A ``MultiCLIPTokenizerItem`` for a single token, or a list of
            them (one per token) when ``new_tokens`` is a list.

        Raises:
            ValueError: If the list lengths differ, if ``num_vectors`` is a
                list while ``new_tokens`` is a single token, or if
                ``num_vectors`` is smaller than 1.
        """
        if isinstance(new_tokens, list):
            if isinstance(num_vectors, int):
                num_vectors = [num_vectors] * len(new_tokens)

            if len(num_vectors) != len(new_tokens):
                raise ValueError("Expected new_tokens and num_vectors to have the same len")

            return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)]

        if isinstance(num_vectors, list):
            raise ValueError("Expected num_vectors to be int for single token")

        if num_vectors < 1:
            raise ValueError("Expected num_vectors to be >= 1")

        # The first sub-token keeps the plain name; the rest get a numeric suffix.
        multi_token = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)]

        super().add_tokens(multi_token)
        ids = super().convert_tokens_to_ids(multi_token)

        self.token_map[ids[0]] = ids

        return MultiCLIPTokenizerItem(new_tokens, ids)

    def expand_id(self, id: int):
        """Return the ids that ``id`` expands to, as a new list.

        Ids not present in ``token_map`` expand to ``[id]``. Registered ids
        expand to their full id list after ``vector_shuffle`` is applied.
        """
        mapped = self.token_map.get(id)
        if mapped is None:
            return [id]
        # Some shuffle functions return their argument unchanged; copy so a
        # caller mutating the result can never alter token_map.
        return list(self.vector_shuffle(mapped))

    def expand_ids(self, ids: list[int]):
        """Expand every id in ``ids`` and concatenate the results in order."""
        return [
            new_id
            for id in ids
            for new_id in self.expand_id(id)
        ]

    def _call_one(self, text, *args, **kwargs):
        """Tokenize via the parent class, then expand multi-token ids.

        ``input_ids`` is treated as batched when it is a non-empty list or
        tuple whose first element is a list; otherwise it is expanded as one
        flat sequence.
        """
        result = super()._call_one(text, *args, **kwargs)
        input_ids = result.input_ids

        # Check for emptiness before peeking at the first element, so an
        # empty id list does not raise IndexError.
        is_batched = (
            isinstance(input_ids, (list, tuple))
            and len(input_ids) > 0
            and isinstance(input_ids[0], list)
        )

        # NOTE(review): the expanded ids are stored via attribute assignment;
        # confirm that result["input_ids"] reflects the expansion as well.
        if is_batched:
            result.input_ids = [self.expand_ids(batch) for batch in input_ids]
        else:
            result.input_ids = self.expand_ids(input_ids)

        return result
|