From a72b6260c117cabe4fcb2996cce4f870986df99b Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 3 Jan 2023 12:40:16 +0100 Subject: Added vector dropout --- models/clip/tokenizer.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) (limited to 'models/clip') diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index bd0bd21..11a3df0 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py @@ -6,6 +6,12 @@ import numpy as np from transformers import CLIPTokenizer +def dropout(tokens: list[int], dropout: float): + if dropout != 0: + tokens = [token for token in tokens if np.random.random() > dropout] + return tokens + + def shuffle_all(tokens: list[int]): if len(tokens) >= 2: tokens = copy.copy(tokens) @@ -59,7 +65,18 @@ class MultiCLIPTokenizer(CLIPTokenizer): super().__init__(*args, **kwargs) self.token_map: dict[int, list[int]] = {} - self.vector_shuffle = shuffle_none + self.is_training = False + self.vector_shuffle = shuffle_auto + self.dropout = 0 + + def train(self): + self.is_training = True + + def eval(self): + self.is_training = False + + def set_dropout(self, dropout: float): + self.dropout = dropout def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]): if algorithm == "leading": @@ -105,7 +122,13 @@ class MultiCLIPTokenizer(CLIPTokenizer): return MultiCLIPTokenizerItem(new_tokens, ids) def expand_id(self, id: int): - return self.vector_shuffle(self.token_map[id]) if id in self.token_map else [id] + if id in self.token_map: + ids = self.token_map[id] + if self.is_training: + ids = dropout(self.vector_shuffle(ids), self.dropout) + return ids + else: + return [id] def expand_ids(self, ids: list[int]): return [ -- cgit v1.2.3-70-g09d2