diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/clip/tokenizer.py | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index a3e6e70..37d69a9 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
@@ -15,6 +15,10 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
15 | def __init__(self, *args, **kwargs): | 15 | def __init__(self, *args, **kwargs): |
16 | super().__init__(*args, **kwargs) | 16 | super().__init__(*args, **kwargs) |
17 | self.token_map: dict[int, list[int]] = {} | 17 | self.token_map: dict[int, list[int]] = {} |
18 | self.vector_shuffle = False | ||
19 | |||
20 | def set_use_vector_shuffle(self, enable: bool): | ||
21 | self.vector_shuffle = enable | ||
18 | 22 | ||
19 | def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem: | 23 | def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem: |
20 | if isinstance(new_tokens, list): | 24 | if isinstance(new_tokens, list): |
@@ -42,11 +46,11 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
42 | 46 | ||
43 | return MultiCLIPTokenizerItem(new_tokens, ids) | 47 | return MultiCLIPTokenizerItem(new_tokens, ids) |
44 | 48 | ||
45 | def expand_id(self, id: int, vector_shuffle=True): | 49 | def expand_id(self, id: int): |
46 | if id in self.token_map: | 50 | if id in self.token_map: |
47 | tokens = self.token_map[id] | 51 | tokens = self.token_map[id] |
48 | 52 | ||
49 | if vector_shuffle and len(tokens) > 2: | 53 | if self.vector_shuffle and len(tokens) > 2: |
50 | subtokens = tokens[1:-1] | 54 | subtokens = tokens[1:-1] |
51 | np.random.shuffle(subtokens) | 55 | np.random.shuffle(subtokens) |
52 | tokens = tokens[:1] + subtokens + tokens[-1:] | 56 | tokens = tokens[:1] + subtokens + tokens[-1:] |
@@ -55,21 +59,21 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
55 | else: | 59 | else: |
56 | return [id] | 60 | return [id] |
57 | 61 | ||
58 | def expand_ids(self, ids: list[int], vector_shuffle=True): | 62 | def expand_ids(self, ids: list[int]): |
59 | return [ | 63 | return [ |
60 | new_id | 64 | new_id |
61 | for id in ids | 65 | for id in ids |
62 | for new_id in self.expand_id(id, vector_shuffle) | 66 | for new_id in self.expand_id(id) |
63 | ] | 67 | ] |
64 | 68 | ||
65 | def _call_one(self, text, *args, vector_shuffle=True, **kwargs): | 69 | def _call_one(self, text, *args, **kwargs): |
66 | result = super()._call_one(text, *args, **kwargs) | 70 | result = super()._call_one(text, *args, **kwargs) |
67 | 71 | ||
68 | is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list) | 72 | is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list) |
69 | 73 | ||
70 | if is_batched: | 74 | if is_batched: |
71 | result.input_ids = [self.expand_ids(batch, vector_shuffle) for batch in result.input_ids] | 75 | result.input_ids = [self.expand_ids(batch) for batch in result.input_ids] |
72 | else: | 76 | else: |
73 | result.input_ids = self.expand_ids(result.input_ids, vector_shuffle) | 77 | result.input_ids = self.expand_ids(result.input_ids) |
74 | 78 | ||
75 | return result | 79 | return result |