summary | refs | log | tree | commit | diff | stats
path: root/models/clip/tokenizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/clip/tokenizer.py')
-rw-r--r-- models/clip/tokenizer.py | 9
1 files changed, 2 insertions, 7 deletions
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 4e97ab5..034adf9 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -55,11 +55,6 @@ def shuffle_auto(tokens: list[int]):
55 return shuffle_all(tokens) 55 return shuffle_all(tokens)
56 56
57 57
58class MultiCLIPTokenizerItem(NamedTuple):
59 token: str
60 ids: list[int]
61
62
63class MultiCLIPTokenizer(CLIPTokenizer): 58class MultiCLIPTokenizer(CLIPTokenizer):
64 def __init__(self, *args, **kwargs): 59 def __init__(self, *args, **kwargs):
65 super().__init__(*args, **kwargs) 60 super().__init__(*args, **kwargs)
@@ -96,7 +91,7 @@ class MultiCLIPTokenizer(CLIPTokenizer):
96 self, 91 self,
97 new_tokens: Union[str, list[str]], 92 new_tokens: Union[str, list[str]],
98 num_vectors: Union[int, list[int]] = 1 93 num_vectors: Union[int, list[int]] = 1
99 ) -> Union[MultiCLIPTokenizerItem, list[MultiCLIPTokenizerItem]]: 94 ) -> Union[list[int], list[list[int]]]:
100 if isinstance(new_tokens, list): 95 if isinstance(new_tokens, list):
101 if isinstance(num_vectors, int): 96 if isinstance(num_vectors, int):
102 num_vectors = [num_vectors] * len(new_tokens) 97 num_vectors = [num_vectors] * len(new_tokens)
@@ -119,7 +114,7 @@ class MultiCLIPTokenizer(CLIPTokenizer):
119 114
120 self.token_map[ids[0]] = ids 115 self.token_map[ids[0]] = ids
121 116
122 return MultiCLIPTokenizerItem(new_tokens, ids) 117 return ids
123 118
124 def expand_id(self, id: int): 119 def expand_id(self, id: int):
125 if id in self.token_map: 120 if id in self.token_map: