summary | refs | log | tree | commit | diff | stats
path: root/models/clip/tokenizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/clip/tokenizer.py')
-rw-r--r-- models/clip/tokenizer.py 23
1 files changed, 12 insertions, 11 deletions
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 789b525..a866641 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -91,18 +91,21 @@ class MultiCLIPTokenizer(CLIPTokenizer):
91 self.vector_shuffle = shuffle_none 91 self.vector_shuffle = shuffle_none
92 92
93 def add_multi_tokens( 93 def add_multi_tokens(
94 self, 94 self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1
95 new_tokens: Union[str, list[str]],
96 num_vectors: Union[int, list[int]] = 1
97 ) -> Union[list[int], list[list[int]]]: 95 ) -> Union[list[int], list[list[int]]]:
98 if isinstance(new_tokens, list): 96 if isinstance(new_tokens, list):
99 if isinstance(num_vectors, int): 97 if isinstance(num_vectors, int):
100 num_vectors = [num_vectors] * len(new_tokens) 98 num_vectors = [num_vectors] * len(new_tokens)
101 99
102 if len(num_vectors) != len(new_tokens): 100 if len(num_vectors) != len(new_tokens):
103 raise ValueError("Expected new_tokens and num_vectors to have the same len") 101 raise ValueError(
102 "Expected new_tokens and num_vectors to have the same len"
103 )
104 104
105 return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)] 105 return [
106 self.add_multi_tokens(new_token, vecs)
107 for new_token, vecs in zip(new_tokens, num_vectors)
108 ]
106 109
107 if isinstance(num_vectors, list): 110 if isinstance(num_vectors, list):
108 raise ValueError("Expected num_vectors to be int for single token") 111 raise ValueError("Expected num_vectors to be int for single token")
@@ -129,13 +132,11 @@ class MultiCLIPTokenizer(CLIPTokenizer):
129 return [id] 132 return [id]
130 133
131 def expand_ids(self, ids: list[int]): 134 def expand_ids(self, ids: list[int]):
132 return [ 135 return [new_id for id in ids for new_id in self.expand_id(id)]
133 new_id
134 for id in ids
135 for new_id in self.expand_id(id)
136 ]
137 136
138 def expand_batched_ids(self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]): 137 def expand_batched_ids(
138 self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]
139 ):
139 if isinstance(input_ids, (list, tuple)) and isinstance(input_ids[0], list): 140 if isinstance(input_ids, (list, tuple)) and isinstance(input_ids[0], list):
140 return [self.expand_ids(batch) for batch in input_ids] 141 return [self.expand_ids(batch) for batch in input_ids]
141 else: 142 else: