diff options
Diffstat (limited to 'models/clip/tokenizer.py')
| -rw-r--r-- | models/clip/tokenizer.py | 64 |
1 file changed, 64 insertions, 0 deletions
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py new file mode 100644 index 0000000..78871db --- /dev/null +++ b/models/clip/tokenizer.py | |||
| @@ -0,0 +1,64 @@ | |||
| 1 | import copy | ||
| 2 | from typing import NamedTuple, Union | ||
| 3 | |||
| 4 | import numpy as np | ||
| 5 | |||
| 6 | from transformers import CLIPTokenizer | ||
| 7 | |||
| 8 | |||
class MultiCLIPTokenizerItem(NamedTuple):
    """Result record for registering one multi-vector placeholder token.

    Fields:
        token:          the placeholder token string as passed by the caller.
        placeholder_id: vocabulary id assigned to the placeholder token itself.
        multi_ids:      vocabulary ids of the expansion tokens the placeholder
                        maps to (a single-element list when num_vectors == 1).
    """

    token: str
    placeholder_id: int
    multi_ids: list[int]
| 13 | |||
| 14 | |||
class MultiCLIPTokenizer(CLIPTokenizer):
    """CLIPTokenizer that can expand one placeholder token into several ids.

    A "multi token" is a placeholder string (e.g. ``"<concept>"``) that is
    registered together with N auxiliary tokens (``"<concept>_0"`` ...
    ``"<concept>_{N-1}"``).  During :meth:`encode`, the placeholder's single id
    is replaced in-line by the ids of its auxiliary tokens, optionally in a
    shuffled order.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps the placeholder token's id -> list of ids it expands to.
        self.token_map: dict[int, list[int]] = {}

    def add_multi_tokens(
        self,
        new_tokens: Union[str, list[str]],
        num_vectors: Union[int, list[int]] = 1,
    ) -> Union[MultiCLIPTokenizerItem, list[MultiCLIPTokenizerItem]]:
        """Register one placeholder token (or a list of them) in the vocabulary.

        Args:
            new_tokens: a placeholder string, or a list of placeholder strings.
            num_vectors: how many embedding vectors each placeholder expands
                to.  An int applies to every token; a list must be the same
                length as ``new_tokens``.

        Returns:
            A :class:`MultiCLIPTokenizerItem` for a single token, or a list of
            them when ``new_tokens`` is a list.

        Raises:
            ValueError: on mismatched list lengths, a list ``num_vectors`` for
                a single token, or ``num_vectors < 1``.
        """
        if isinstance(new_tokens, list):
            if isinstance(num_vectors, int):
                # Broadcast a single vector count over every token.
                num_vectors = [num_vectors] * len(new_tokens)

            if len(num_vectors) != len(new_tokens):
                raise ValueError("Expected new_tokens and num_vectors to have the same len")

            return [
                self.add_multi_tokens(new_token, vecs)
                for new_token, vecs in zip(new_tokens, num_vectors)
            ]

        if isinstance(num_vectors, list):
            raise ValueError("Expected num_vectors to be int for single token")

        if num_vectors < 1:
            # A placeholder expanding to zero vectors would be silently
            # deleted from every encoding; reject it explicitly.
            raise ValueError("Expected num_vectors to be >= 1")

        super().add_tokens(new_tokens)

        if num_vectors == 1:
            # The placeholder expands to itself only.
            multi_token = [new_tokens]
        else:
            multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)]
        # Re-adding an already-known token is a no-op, so this is safe in
        # the num_vectors == 1 case too.
        super().add_tokens(multi_token)

        meta_id = super().convert_tokens_to_ids(new_tokens)
        multi_ids = super().convert_tokens_to_ids(multi_token)

        self.token_map[meta_id] = multi_ids

        return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids)

    def encode(self, *args, vector_shuffle: bool = True, **kwargs):
        """Encode text, expanding each registered placeholder id in-line.

        Args:
            vector_shuffle: when True, the expansion ids of each placeholder
                are emitted in a random order (uses ``np.random``).

        Returns:
            The list of token ids with placeholders replaced by their
            expansions.
        """
        ids = super().encode(*args, **kwargs)
        new_ids = []

        for token_id in ids:
            if token_id in self.token_map:
                replacement = self.token_map[token_id]

                if vector_shuffle:
                    # Shuffle a copy so the stored mapping stays in
                    # canonical order.
                    replacement = copy.copy(replacement)
                    np.random.shuffle(replacement)

                # Bug fix: the original appended self.token_map[id] here,
                # discarding the shuffled copy it had just built.
                new_ids.extend(replacement)
            else:
                new_ids.append(token_id)

        return new_ids
