diff options
author | Volpeon <git@volpeon.ink> | 2023-01-03 12:40:16 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-03 12:40:16 +0100 |
commit | a72b6260c117cabe4fcb2996cce4f870986df99b (patch) | |
tree | 7c9c7704c6ef60a4ab886d5acbce4e6e22398b56 /models | |
parent | Fixed LR finder (diff) | |
download | textual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.tar.gz textual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.tar.bz2 textual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.zip |
Added vector dropout
Diffstat (limited to 'models')
-rw-r--r-- | models/clip/tokenizer.py | 27 |
1 files changed, 25 insertions, 2 deletions
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index bd0bd21..11a3df0 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
@@ -6,6 +6,12 @@ import numpy as np | |||
6 | from transformers import CLIPTokenizer | 6 | from transformers import CLIPTokenizer |
7 | 7 | ||
8 | 8 | ||
9 | def dropout(tokens: list[int], dropout: float): | ||
10 | if dropout != 0: | ||
11 | tokens = [token for token in tokens if np.random.random() > dropout] | ||
12 | return tokens | ||
13 | |||
14 | |||
9 | def shuffle_all(tokens: list[int]): | 15 | def shuffle_all(tokens: list[int]): |
10 | if len(tokens) >= 2: | 16 | if len(tokens) >= 2: |
11 | tokens = copy.copy(tokens) | 17 | tokens = copy.copy(tokens) |
@@ -59,7 +65,18 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
59 | super().__init__(*args, **kwargs) | 65 | super().__init__(*args, **kwargs) |
60 | 66 | ||
61 | self.token_map: dict[int, list[int]] = {} | 67 | self.token_map: dict[int, list[int]] = {} |
62 | self.vector_shuffle = shuffle_none | 68 | self.is_training = False |
69 | self.vector_shuffle = shuffle_auto | ||
70 | self.dropout = 0 | ||
71 | |||
72 | def train(self): | ||
73 | self.is_training = True | ||
74 | |||
75 | def eval(self): | ||
76 | self.is_training = False | ||
77 | |||
78 | def set_dropout(self, dropout: float): | ||
79 | self.dropout = dropout | ||
63 | 80 | ||
64 | def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]): | 81 | def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]): |
65 | if algorithm == "leading": | 82 | if algorithm == "leading": |
@@ -105,7 +122,13 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
105 | return MultiCLIPTokenizerItem(new_tokens, ids) | 122 | return MultiCLIPTokenizerItem(new_tokens, ids) |
106 | 123 | ||
107 | def expand_id(self, id: int): | 124 | def expand_id(self, id: int): |
108 | return self.vector_shuffle(self.token_map[id]) if id in self.token_map else [id] | 125 | if id in self.token_map: |
126 | ids = self.token_map[id] | ||
127 | if self.is_training: | ||
128 | ids = dropout(self.vector_shuffle(ids), self.dropout) | ||
129 | return ids | ||
130 | else: | ||
131 | return [id] | ||
109 | 132 | ||
110 | def expand_ids(self, ids: list[int]): | 133 | def expand_ids(self, ids: list[int]): |
111 | return [ | 134 | return [ |