diff options
Diffstat (limited to 'models/clip/tokenizer.py')
| -rw-r--r-- | models/clip/tokenizer.py | 39 |
1 files changed, 26 insertions, 13 deletions
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 7e08287..63566e0 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
| @@ -44,20 +44,33 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
| 44 | 44 | ||
| 45 | return MultiCLIPTokenizerItem(new_tokens, meta_id, ids) | 45 | return MultiCLIPTokenizerItem(new_tokens, meta_id, ids) |
| 46 | 46 | ||
| 47 | def encode(self, *args, vector_shuffle=True, **kwargs): | 47 | def expand_id(self, id: int, vector_shuffle=True): |
| 48 | ids = super().encode(*args, **kwargs) | 48 | if id in self.token_map: |
| 49 | new_ids = [] | 49 | tokens = self.token_map[id] |
| 50 | 50 | ||
| 51 | for id in ids: | 51 | if vector_shuffle: |
| 52 | if id in self.token_map: | 52 | tokens = copy.copy(tokens) |
| 53 | tokens = self.token_map[id] | 53 | np.random.shuffle(tokens) |
| 54 | 54 | ||
| 55 | if vector_shuffle: | 55 | return tokens |
| 56 | tokens = copy.copy(tokens) | 56 | else: |
| 57 | np.random.shuffle(tokens) | 57 | return [id] |
| 58 | 58 | ||
| 59 | new_ids = new_ids + self.token_map[id] | 59 | def expand_ids(self, ids: list[int], vector_shuffle=True): |
| 60 | else: | 60 | return [ |
| 61 | new_ids.append(id) | 61 | new_id |
| 62 | for id in ids | ||
| 63 | for new_id in self.expand_id(id, vector_shuffle) | ||
| 64 | ] | ||
| 62 | 65 | ||
| 63 | return new_ids | 66 | def _call_one(self, text, *args, vector_shuffle=True, **kwargs): |
| 67 | result = super()._call_one(text, *args, **kwargs) | ||
| 68 | |||
| 69 | is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list) | ||
| 70 | |||
| 71 | if is_batched: | ||
| 72 | result.input_ids = [self.expand_ids(batch, vector_shuffle) for batch in result.input_ids] | ||
| 73 | else: | ||
| 74 | result.input_ids = self.expand_ids(result.input_ids, vector_shuffle) | ||
| 75 | |||
| 76 | return result | ||
