diff options
Diffstat (limited to 'models/clip/tokenizer.py')
-rw-r--r-- | models/clip/tokenizer.py | 39 |
1 files changed, 26 insertions, 13 deletions
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 7e08287..63566e0 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
@@ -44,20 +44,33 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
44 | 44 | ||
45 | return MultiCLIPTokenizerItem(new_tokens, meta_id, ids) | 45 | return MultiCLIPTokenizerItem(new_tokens, meta_id, ids) |
46 | 46 | ||
47 | def encode(self, *args, vector_shuffle=True, **kwargs): | 47 | def expand_id(self, id: int, vector_shuffle=True): |
48 | ids = super().encode(*args, **kwargs) | 48 | if id in self.token_map: |
49 | new_ids = [] | 49 | tokens = self.token_map[id] |
50 | 50 | ||
51 | for id in ids: | 51 | if vector_shuffle: |
52 | if id in self.token_map: | 52 | tokens = copy.copy(tokens) |
53 | tokens = self.token_map[id] | 53 | np.random.shuffle(tokens) |
54 | 54 | ||
55 | if vector_shuffle: | 55 | return tokens |
56 | tokens = copy.copy(tokens) | 56 | else: |
57 | np.random.shuffle(tokens) | 57 | return [id] |
58 | 58 | ||
59 | new_ids = new_ids + self.token_map[id] | 59 | def expand_ids(self, ids: list[int], vector_shuffle=True): |
60 | else: | 60 | return [ |
61 | new_ids.append(id) | 61 | new_id |
62 | for id in ids | ||
63 | for new_id in self.expand_id(id, vector_shuffle) | ||
64 | ] | ||
62 | 65 | ||
63 | return new_ids | 66 | def _call_one(self, text, *args, vector_shuffle=True, **kwargs): |
67 | result = super()._call_one(text, *args, **kwargs) | ||
68 | |||
69 | is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list) | ||
70 | |||
71 | if is_batched: | ||
72 | result.input_ids = [self.expand_ids(batch, vector_shuffle) for batch in result.input_ids] | ||
73 | else: | ||
74 | result.input_ids = self.expand_ids(result.input_ids, vector_shuffle) | ||
75 | |||
76 | return result | ||