From b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 1 Jan 2023 11:36:00 +0100 Subject: Fixed accuracy calc, other improvements --- models/clip/tokenizer.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) (limited to 'models') diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index a3e6e70..37d69a9 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py @@ -15,6 +15,10 @@ class MultiCLIPTokenizer(CLIPTokenizer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.token_map: dict[int, list[int]] = {} + self.vector_shuffle = False + + def set_use_vector_shuffle(self, enable: bool): + self.vector_shuffle = enable def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem: if isinstance(new_tokens, list): @@ -42,11 +46,11 @@ class MultiCLIPTokenizer(CLIPTokenizer): return MultiCLIPTokenizerItem(new_tokens, ids) - def expand_id(self, id: int, vector_shuffle=True): + def expand_id(self, id: int): if id in self.token_map: tokens = self.token_map[id] - if vector_shuffle and len(tokens) > 2: + if self.vector_shuffle and len(tokens) > 2: subtokens = tokens[1:-1] np.random.shuffle(subtokens) tokens = tokens[:1] + subtokens + tokens[-1:] @@ -55,21 +59,21 @@ class MultiCLIPTokenizer(CLIPTokenizer): else: return [id] - def expand_ids(self, ids: list[int], vector_shuffle=True): + def expand_ids(self, ids: list[int]): return [ new_id for id in ids - for new_id in self.expand_id(id, vector_shuffle) + for new_id in self.expand_id(id) ] - def _call_one(self, text, *args, vector_shuffle=True, **kwargs): + def _call_one(self, text, *args, **kwargs): result = super()._call_one(text, *args, **kwargs) is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list) if is_batched: - result.input_ids = [self.expand_ids(batch, vector_shuffle) for batch in result.input_ids] + result.input_ids = [self.expand_ids(batch) for batch in result.input_ids] else: - result.input_ids = self.expand_ids(result.input_ids, vector_shuffle) + result.input_ids = self.expand_ids(result.input_ids) return result -- cgit v1.2.3-54-g00ecf