summary refs log tree commit diff stats
path: root/models/clip/tokenizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/clip/tokenizer.py')
-rw-r--r-- models/clip/tokenizer.py | 64
1 file changed, 64 insertions(+), 0 deletions(-)
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
new file mode 100644
index 0000000..78871db
--- /dev/null
+++ b/models/clip/tokenizer.py
@@ -0,0 +1,64 @@
1import copy
2from typing import NamedTuple, Union
3
4import numpy as np
5
6from transformers import CLIPTokenizer
7
8
# Record returned by MultiCLIPTokenizer.add_multi_tokens: the placeholder
# token string, the id assigned to it, and the ids of its vector sub-tokens.
MultiCLIPTokenizerItem = NamedTuple(
    "MultiCLIPTokenizerItem",
    [
        ("token", str),
        ("placeholder_id", int),
        ("multi_ids", list[int]),
    ],
)
13
14
class MultiCLIPTokenizer(CLIPTokenizer):
    """CLIP tokenizer in which one placeholder token can stand for several
    embedding vectors (textual-inversion style).

    A "meta" token registered via :meth:`add_multi_tokens` is expanded by
    :meth:`encode` into the ids of its sub-tokens.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps a meta token id -> the list of sub-token ids it expands to.
        self.token_map: dict[int, list[int]] = {}

    def add_multi_tokens(
        self,
        new_tokens: Union[str, list[str]],
        num_vectors: Union[int, list[int]] = 1,
    ) -> Union[MultiCLIPTokenizerItem, list[MultiCLIPTokenizerItem]]:
        """Register one placeholder token (or several) backed by
        ``num_vectors`` embedding vectors each.

        Args:
            new_tokens: A placeholder token string, or a list of them.
            num_vectors: Vectors per token. An int is broadcast over a list
                of tokens; a list must match ``new_tokens`` in length.

        Returns:
            A ``MultiCLIPTokenizerItem`` for a single token, or a list of
            them when ``new_tokens`` is a list.

        Raises:
            ValueError: If list lengths mismatch, if ``num_vectors`` is a
                list for a single token, or if ``num_vectors < 1``.
        """
        if isinstance(new_tokens, list):
            if isinstance(num_vectors, int):
                # Broadcast a single vector count over all new tokens.
                num_vectors = [num_vectors] * len(new_tokens)

            if len(num_vectors) != len(new_tokens):
                raise ValueError("Expected new_tokens and num_vectors to have the same len")

            return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)]

        if isinstance(num_vectors, list):
            raise ValueError("Expected num_vectors to be int for single token")

        if num_vectors < 1:
            # Zero vectors would map the token to an empty id list, silently
            # erasing it from every encoding.
            raise ValueError("Expected num_vectors to be >= 1")

        super().add_tokens(new_tokens)

        if num_vectors == 1:
            multi_token = [new_tokens]
        else:
            # Sub-tokens "<tok>_0" ... "<tok>_{n-1}" carry the actual vectors.
            multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)]
        # Re-adding an already-known token is a no-op in the base tokenizer.
        super().add_tokens(multi_token)

        meta_id = super().convert_tokens_to_ids(new_tokens)
        multi_ids = super().convert_tokens_to_ids(multi_token)

        self.token_map[meta_id] = multi_ids

        return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids)

    def encode(self, *args, vector_shuffle=True, **kwargs):
        """Encode text, expanding every registered meta token into its
        sub-token ids.

        Args:
            vector_shuffle: When True, each meta token's sub-token ids are
                emitted in a random order (the stored mapping is never
                mutated).

        Returns:
            list[int]: Token ids with meta tokens replaced by their
            (optionally shuffled) sub-token ids.
        """
        ids = super().encode(*args, **kwargs)
        new_ids = []

        for token_id in ids:
            if token_id in self.token_map:
                multi_ids = self.token_map[token_id]

                if vector_shuffle:
                    # Shuffle a copy so the canonical order in token_map
                    # stays intact for future calls.
                    multi_ids = copy.copy(multi_ids)
                    np.random.shuffle(multi_ids)

                # BUGFIX: previously this appended self.token_map[token_id]
                # (the unshuffled original), discarding the shuffle above.
                new_ids += multi_ids
            else:
                new_ids.append(token_id)

        return new_ids