summaryrefslogtreecommitdiffstats
path: root/models/clip/tokenizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/clip/tokenizer.py')
-rw-r--r--models/clip/tokenizer.py39
1 files changed, 26 insertions, 13 deletions
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 7e08287..63566e0 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -44,20 +44,33 @@ class MultiCLIPTokenizer(CLIPTokenizer):
44 44
45 return MultiCLIPTokenizerItem(new_tokens, meta_id, ids) 45 return MultiCLIPTokenizerItem(new_tokens, meta_id, ids)
46 46
47 def encode(self, *args, vector_shuffle=True, **kwargs): 47 def expand_id(self, id: int, vector_shuffle=True):
48 ids = super().encode(*args, **kwargs) 48 if id in self.token_map:
49 new_ids = [] 49 tokens = self.token_map[id]
50 50
51 for id in ids: 51 if vector_shuffle:
52 if id in self.token_map: 52 tokens = copy.copy(tokens)
53 tokens = self.token_map[id] 53 np.random.shuffle(tokens)
54 54
55 if vector_shuffle: 55 return tokens
56 tokens = copy.copy(tokens) 56 else:
57 np.random.shuffle(tokens) 57 return [id]
58 58
59 new_ids = new_ids + self.token_map[id] 59 def expand_ids(self, ids: list[int], vector_shuffle=True):
60 else: 60 return [
61 new_ids.append(id) 61 new_id
62 for id in ids
63 for new_id in self.expand_id(id, vector_shuffle)
64 ]
62 65
63 return new_ids 66 def _call_one(self, text, *args, vector_shuffle=True, **kwargs):
67 result = super()._call_one(text, *args, **kwargs)
68
69 is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list)
70
71 if is_batched:
72 result.input_ids = [self.expand_ids(batch, vector_shuffle) for batch in result.input_ids]
73 else:
74 result.input_ids = self.expand_ids(result.input_ids, vector_shuffle)
75
76 return result