diff options
Diffstat (limited to 'models/clip')
-rw-r--r-- | models/clip/tokenizer.py | 23 |
1 files changed, 11 insertions, 12 deletions
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 78871db..7e08287 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
@@ -8,8 +8,8 @@ from transformers import CLIPTokenizer | |||
8 | 8 | ||
9 | class MultiCLIPTokenizerItem(NamedTuple): | 9 | class MultiCLIPTokenizerItem(NamedTuple): |
10 | token: str | 10 | token: str |
11 | placeholder_id: int | 11 | meta_id: int |
12 | multi_ids: list[int] | 12 | ids: list[int] |
13 | 13 | ||
14 | 14 | ||
15 | class MultiCLIPTokenizer(CLIPTokenizer): | 15 | class MultiCLIPTokenizer(CLIPTokenizer): |
@@ -30,20 +30,19 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
30 | if isinstance(num_vectors, list): | 30 | if isinstance(num_vectors, list): |
31 | raise ValueError("Expected num_vectors to be int for single token") | 31 | raise ValueError("Expected num_vectors to be int for single token") |
32 | 32 | ||
33 | super().add_tokens(new_tokens) | 33 | if num_vectors < 1: |
34 | raise ValueError("Expected num_vectors to be >= 1") | ||
34 | 35 | ||
35 | if num_vectors == 1: | 36 | multi_token = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)] |
36 | multi_token = [new_tokens] | ||
37 | else: | ||
38 | multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)] | ||
39 | super().add_tokens(multi_token) | ||
40 | 37 | ||
41 | meta_id = super().convert_tokens_to_ids(new_tokens) | 38 | super().add_tokens(multi_token) |
42 | multi_ids = super().convert_tokens_to_ids(multi_token) | ||
43 | 39 | ||
44 | self.token_map[meta_id] = multi_ids | 40 | ids = super().convert_tokens_to_ids(multi_token) |
41 | meta_id = ids[0] | ||
45 | 42 | ||
46 | return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids) | 43 | self.token_map[meta_id] = ids |
44 | |||
45 | return MultiCLIPTokenizerItem(new_tokens, meta_id, ids) | ||
47 | 46 | ||
48 | def encode(self, *args, vector_shuffle=True, **kwargs): | 47 | def encode(self, *args, vector_shuffle=True, **kwargs): |
49 | ids = super().encode(*args, **kwargs) | 48 | ids = super().encode(*args, **kwargs) |