diff options
Diffstat (limited to 'models/clip')
-rw-r--r-- | models/clip/tokenizer.py | 33 |
1 files changed, 20 insertions, 13 deletions
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index ed9774e..5e33f3e 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
@@ -74,7 +74,11 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
74 | else: | 74 | else: |
75 | self.vector_shuffle = shuffle_none | 75 | self.vector_shuffle = shuffle_none |
76 | 76 | ||
77 | def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem: | 77 | def add_multi_tokens( |
78 | self, | ||
79 | new_tokens: Union[str, list[str]], | ||
80 | num_vectors: Union[int, list[int]] = 1 | ||
81 | ) -> Union[MultiCLIPTokenizerItem, list[MultiCLIPTokenizerItem]]: | ||
78 | if isinstance(new_tokens, list): | 82 | if isinstance(new_tokens, list): |
79 | if isinstance(num_vectors, int): | 83 | if isinstance(num_vectors, int): |
80 | num_vectors = [num_vectors] * len(new_tokens) | 84 | num_vectors = [num_vectors] * len(new_tokens) |
@@ -90,11 +94,10 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
90 | if num_vectors < 1: | 94 | if num_vectors < 1: |
91 | raise ValueError("Expected num_vectors to be >= 1") | 95 | raise ValueError("Expected num_vectors to be >= 1") |
92 | 96 | ||
93 | multi_token = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)] | 97 | tokens = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)] |
94 | 98 | ||
95 | super().add_tokens(multi_token) | 99 | super().add_tokens(tokens) |
96 | 100 | ids = super().convert_tokens_to_ids(tokens) | |
97 | ids = super().convert_tokens_to_ids(multi_token) | ||
98 | 101 | ||
99 | self.token_map[ids[0]] = ids | 102 | self.token_map[ids[0]] = ids |
100 | 103 | ||
@@ -110,14 +113,18 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
110 | for new_id in self.expand_id(id) | 113 | for new_id in self.expand_id(id) |
111 | ] | 114 | ] |
112 | 115 | ||
113 | def _call_one(self, text, *args, **kwargs): | 116 | def expand_batched_ids(self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]): |
114 | result = super()._call_one(text, *args, **kwargs) | 117 | if isinstance(input_ids, (list, tuple)) and isinstance(input_ids[0], list): |
115 | 118 | return [self.expand_ids(batch) for batch in input_ids] | |
116 | is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list) | ||
117 | |||
118 | if is_batched: | ||
119 | result.input_ids = [self.expand_ids(batch) for batch in result.input_ids] | ||
120 | else: | 119 | else: |
121 | result.input_ids = self.expand_ids(result.input_ids) | 120 | return self.expand_ids(input_ids) |
121 | |||
122 | def _call_one(self, *args, **kwargs): | ||
123 | result = super()._call_one(*args, **kwargs) | ||
124 | result.input_ids = self.expand_batched_ids(result.input_ids) | ||
125 | return result | ||
122 | 126 | ||
127 | def encode(self, *args, **kwargs): | ||
128 | result = super().encode(*args, **kwargs) | ||
129 | result = self.expand_batched_ids(result) | ||
123 | return result | 130 | return result |