summary | refs | log | tree | commit | diff | stats
path: root/models/clip
diff options
context:
space:
mode:
Diffstat (limited to 'models/clip')
-rw-r--r-- models/clip/tokenizer.py 33
1 file changed, 20 insertions, 13 deletions
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index ed9774e..5e33f3e 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -74,7 +74,11 @@ class MultiCLIPTokenizer(CLIPTokenizer):
74 else: 74 else:
75 self.vector_shuffle = shuffle_none 75 self.vector_shuffle = shuffle_none
76 76
77 def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem: 77 def add_multi_tokens(
78 self,
79 new_tokens: Union[str, list[str]],
80 num_vectors: Union[int, list[int]] = 1
81 ) -> Union[MultiCLIPTokenizerItem, list[MultiCLIPTokenizerItem]]:
78 if isinstance(new_tokens, list): 82 if isinstance(new_tokens, list):
79 if isinstance(num_vectors, int): 83 if isinstance(num_vectors, int):
80 num_vectors = [num_vectors] * len(new_tokens) 84 num_vectors = [num_vectors] * len(new_tokens)
@@ -90,11 +94,10 @@ class MultiCLIPTokenizer(CLIPTokenizer):
90 if num_vectors < 1: 94 if num_vectors < 1:
91 raise ValueError("Expected num_vectors to be >= 1") 95 raise ValueError("Expected num_vectors to be >= 1")
92 96
93 multi_token = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)] 97 tokens = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)]
94 98
95 super().add_tokens(multi_token) 99 super().add_tokens(tokens)
96 100 ids = super().convert_tokens_to_ids(tokens)
97 ids = super().convert_tokens_to_ids(multi_token)
98 101
99 self.token_map[ids[0]] = ids 102 self.token_map[ids[0]] = ids
100 103
@@ -110,14 +113,18 @@ class MultiCLIPTokenizer(CLIPTokenizer):
110 for new_id in self.expand_id(id) 113 for new_id in self.expand_id(id)
111 ] 114 ]
112 115
113 def _call_one(self, text, *args, **kwargs): 116 def expand_batched_ids(self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]):
114 result = super()._call_one(text, *args, **kwargs) 117 if isinstance(input_ids, (list, tuple)) and isinstance(input_ids[0], list):
115 118 return [self.expand_ids(batch) for batch in input_ids]
116 is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list)
117
118 if is_batched:
119 result.input_ids = [self.expand_ids(batch) for batch in result.input_ids]
120 else: 119 else:
121 result.input_ids = self.expand_ids(result.input_ids) 120 return self.expand_ids(input_ids)
121
122 def _call_one(self, *args, **kwargs):
123 result = super()._call_one(*args, **kwargs)
124 result.input_ids = self.expand_batched_ids(result.input_ids)
125 return result
122 126
127 def encode(self, *args, **kwargs):
128 result = super().encode(*args, **kwargs)
129 result = self.expand_batched_ids(result)
123 return result 130 return result