diff options
| -rw-r--r-- | models/clip/tokenizer.py | 33 |
1 file changed, 20 insertions, 13 deletions
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index ed9774e..5e33f3e 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
| @@ -74,7 +74,11 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
| 74 | else: | 74 | else: |
| 75 | self.vector_shuffle = shuffle_none | 75 | self.vector_shuffle = shuffle_none |
| 76 | 76 | ||
| 77 | def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem: | 77 | def add_multi_tokens( |
| 78 | self, | ||
| 79 | new_tokens: Union[str, list[str]], | ||
| 80 | num_vectors: Union[int, list[int]] = 1 | ||
| 81 | ) -> Union[MultiCLIPTokenizerItem, list[MultiCLIPTokenizerItem]]: | ||
| 78 | if isinstance(new_tokens, list): | 82 | if isinstance(new_tokens, list): |
| 79 | if isinstance(num_vectors, int): | 83 | if isinstance(num_vectors, int): |
| 80 | num_vectors = [num_vectors] * len(new_tokens) | 84 | num_vectors = [num_vectors] * len(new_tokens) |
| @@ -90,11 +94,10 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
| 90 | if num_vectors < 1: | 94 | if num_vectors < 1: |
| 91 | raise ValueError("Expected num_vectors to be >= 1") | 95 | raise ValueError("Expected num_vectors to be >= 1") |
| 92 | 96 | ||
| 93 | multi_token = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)] | 97 | tokens = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)] |
| 94 | 98 | ||
| 95 | super().add_tokens(multi_token) | 99 | super().add_tokens(tokens) |
| 96 | 100 | ids = super().convert_tokens_to_ids(tokens) | |
| 97 | ids = super().convert_tokens_to_ids(multi_token) | ||
| 98 | 101 | ||
| 99 | self.token_map[ids[0]] = ids | 102 | self.token_map[ids[0]] = ids |
| 100 | 103 | ||
| @@ -110,14 +113,18 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
| 110 | for new_id in self.expand_id(id) | 113 | for new_id in self.expand_id(id) |
| 111 | ] | 114 | ] |
| 112 | 115 | ||
| 113 | def _call_one(self, text, *args, **kwargs): | 116 | def expand_batched_ids(self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]): |
| 114 | result = super()._call_one(text, *args, **kwargs) | 117 | if isinstance(input_ids, (list, tuple)) and isinstance(input_ids[0], list): |
| 115 | 118 | return [self.expand_ids(batch) for batch in input_ids] | |
| 116 | is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list) | ||
| 117 | |||
| 118 | if is_batched: | ||
| 119 | result.input_ids = [self.expand_ids(batch) for batch in result.input_ids] | ||
| 120 | else: | 119 | else: |
| 121 | result.input_ids = self.expand_ids(result.input_ids) | 120 | return self.expand_ids(input_ids) |
| 121 | |||
| 122 | def _call_one(self, *args, **kwargs): | ||
| 123 | result = super()._call_one(*args, **kwargs) | ||
| 124 | result.input_ids = self.expand_batched_ids(result.input_ids) | ||
| 125 | return result | ||
| 122 | 126 | ||
| 127 | def encode(self, *args, **kwargs): | ||
| 128 | result = super().encode(*args, **kwargs) | ||
| 129 | result = self.expand_batched_ids(result) | ||
| 123 | return result | 130 | return result |
