diff options
Diffstat (limited to 'models/clip/tokenizer.py')
-rw-r--r-- | models/clip/tokenizer.py | 64 |
1 file changed, 64 insertions, 0 deletions
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py new file mode 100644 index 0000000..78871db --- /dev/null +++ b/models/clip/tokenizer.py | |||
@@ -0,0 +1,64 @@ | |||
1 | import copy | ||
2 | from typing import NamedTuple, Union | ||
3 | |||
4 | import numpy as np | ||
5 | |||
6 | from transformers import CLIPTokenizer | ||
7 | |||
8 | |||
class MultiCLIPTokenizerItem(NamedTuple):
    """Result record returned when a multi-vector placeholder token is registered.

    Bundles the placeholder string together with the vocabulary ids it was
    assigned by the tokenizer.
    """

    # The placeholder token string exactly as passed to add_multi_tokens().
    token: str
    # Vocabulary id assigned to the placeholder token itself.
    placeholder_id: int
    # Vocabulary ids of the per-vector tokens the placeholder expands to.
    multi_ids: list[int]
13 | |||
14 | |||
class MultiCLIPTokenizer(CLIPTokenizer):
    """CLIPTokenizer extension that lets one placeholder token expand into
    several consecutive token ids ("multi-vector" textual-inversion tokens).

    A placeholder registered via :meth:`add_multi_tokens` is stored in
    ``self.token_map``; :meth:`encode` replaces its id with the mapped list of
    ids (optionally shuffled) in the encoded output.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # placeholder token id -> list of ids it expands to during encode()
        self.token_map: dict[int, list[int]] = {}

    def add_multi_tokens(
        self,
        new_tokens: Union[str, list[str]],
        num_vectors: Union[int, list[int]] = 1,
    ) -> Union[MultiCLIPTokenizerItem, list[MultiCLIPTokenizerItem]]:
        """Register one placeholder token (or a list of them) in the vocabulary.

        Args:
            new_tokens: A placeholder string, or a list of placeholder strings.
            num_vectors: Number of embedding vectors the placeholder expands
                to. With a list of tokens, may be a matching list of ints or a
                single int applied to every token.

        Returns:
            A ``MultiCLIPTokenizerItem`` for a single token, or a list of them
            when ``new_tokens`` is a list.

        Raises:
            ValueError: If the lengths of ``new_tokens`` and ``num_vectors``
                disagree, or ``num_vectors`` is a list for a single token.
        """
        if isinstance(new_tokens, list):
            if isinstance(num_vectors, int):
                # Broadcast a single vector count across all tokens.
                num_vectors = [num_vectors] * len(new_tokens)

            if len(num_vectors) != len(new_tokens):
                raise ValueError("Expected new_tokens and num_vectors to have the same len")

            return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)]

        if isinstance(num_vectors, list):
            raise ValueError("Expected num_vectors to be int for single token")

        # Register the placeholder itself; it is what users type in prompts.
        super().add_tokens(new_tokens)

        if num_vectors == 1:
            # Single vector: the placeholder maps to itself.
            multi_token = [new_tokens]
        else:
            # Multiple vectors: register one suffixed token per vector.
            multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)]
            super().add_tokens(multi_token)

        meta_id = super().convert_tokens_to_ids(new_tokens)
        multi_ids = super().convert_tokens_to_ids(multi_token)

        self.token_map[meta_id] = multi_ids

        return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids)

    def encode(self, *args, vector_shuffle: bool = True, **kwargs):
        """Encode text, expanding registered placeholder ids via ``token_map``.

        Args:
            vector_shuffle: When True, the expansion ids of each placeholder
                are shuffled (on a copy, leaving ``token_map`` intact) before
                being spliced into the output.

        Returns:
            The list of token ids with every mapped placeholder id replaced by
            its expansion.
        """
        ids = super().encode(*args, **kwargs)
        new_ids = []

        for token_id in ids:  # renamed from `id` to avoid shadowing the builtin
            if token_id in self.token_map:
                replacement = self.token_map[token_id]

                if vector_shuffle:
                    # Shuffle a copy so the stored mapping is never reordered.
                    replacement = copy.copy(replacement)
                    np.random.shuffle(replacement)

                # Bug fix: the original appended the un-shuffled
                # self.token_map[id] here, silently discarding the shuffle.
                # Also use extend() instead of quadratic list concatenation.
                new_ids.extend(replacement)
            else:
                new_ids.append(token_id)

        return new_ids