diff options
Diffstat (limited to 'models/clip')
| -rw-r--r-- | models/clip/tokenizer.py | 23 |
1 files changed, 11 insertions, 12 deletions
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 78871db..7e08287 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
| @@ -8,8 +8,8 @@ from transformers import CLIPTokenizer | |||
| 8 | 8 | ||
| 9 | class MultiCLIPTokenizerItem(NamedTuple): | 9 | class MultiCLIPTokenizerItem(NamedTuple): |
| 10 | token: str | 10 | token: str |
| 11 | placeholder_id: int | 11 | meta_id: int |
| 12 | multi_ids: list[int] | 12 | ids: list[int] |
| 13 | 13 | ||
| 14 | 14 | ||
| 15 | class MultiCLIPTokenizer(CLIPTokenizer): | 15 | class MultiCLIPTokenizer(CLIPTokenizer): |
| @@ -30,20 +30,19 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
| 30 | if isinstance(num_vectors, list): | 30 | if isinstance(num_vectors, list): |
| 31 | raise ValueError("Expected num_vectors to be int for single token") | 31 | raise ValueError("Expected num_vectors to be int for single token") |
| 32 | 32 | ||
| 33 | super().add_tokens(new_tokens) | 33 | if num_vectors < 1: |
| 34 | raise ValueError("Expected num_vectors to be >= 1") | ||
| 34 | 35 | ||
| 35 | if num_vectors == 1: | 36 | multi_token = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)] |
| 36 | multi_token = [new_tokens] | ||
| 37 | else: | ||
| 38 | multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)] | ||
| 39 | super().add_tokens(multi_token) | ||
| 40 | 37 | ||
| 41 | meta_id = super().convert_tokens_to_ids(new_tokens) | 38 | super().add_tokens(multi_token) |
| 42 | multi_ids = super().convert_tokens_to_ids(multi_token) | ||
| 43 | 39 | ||
| 44 | self.token_map[meta_id] = multi_ids | 40 | ids = super().convert_tokens_to_ids(multi_token) |
| 41 | meta_id = ids[0] | ||
| 45 | 42 | ||
| 46 | return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids) | 43 | self.token_map[meta_id] = ids |
| 44 | |||
| 45 | return MultiCLIPTokenizerItem(new_tokens, meta_id, ids) | ||
| 47 | 46 | ||
| 48 | def encode(self, *args, vector_shuffle=True, **kwargs): | 47 | def encode(self, *args, vector_shuffle=True, **kwargs): |
| 49 | ids = super().encode(*args, **kwargs) | 48 | ids = super().encode(*args, **kwargs) |
