summaryrefslogtreecommitdiffstats
path: root/models/clip
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-03 12:40:16 +0100
committerVolpeon <git@volpeon.ink>2023-01-03 12:40:16 +0100
commita72b6260c117cabe4fcb2996cce4f870986df99b (patch)
tree7c9c7704c6ef60a4ab886d5acbce4e6e22398b56 /models/clip
parentFixed LR finder (diff)
downloadtextual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.tar.gz
textual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.tar.bz2
textual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.zip
Added vector dropout
Diffstat (limited to 'models/clip')
-rw-r--r--models/clip/tokenizer.py27
1 files changed, 25 insertions, 2 deletions
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index bd0bd21..11a3df0 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -6,6 +6,12 @@ import numpy as np
6from transformers import CLIPTokenizer 6from transformers import CLIPTokenizer
7 7
8 8
9def dropout(tokens: list[int], dropout: float):
10 if dropout != 0:
11 tokens = [token for token in tokens if np.random.random() > dropout]
12 return tokens
13
14
9def shuffle_all(tokens: list[int]): 15def shuffle_all(tokens: list[int]):
10 if len(tokens) >= 2: 16 if len(tokens) >= 2:
11 tokens = copy.copy(tokens) 17 tokens = copy.copy(tokens)
@@ -59,7 +65,18 @@ class MultiCLIPTokenizer(CLIPTokenizer):
59 super().__init__(*args, **kwargs) 65 super().__init__(*args, **kwargs)
60 66
61 self.token_map: dict[int, list[int]] = {} 67 self.token_map: dict[int, list[int]] = {}
62 self.vector_shuffle = shuffle_none 68 self.is_training = False
69 self.vector_shuffle = shuffle_auto
70 self.dropout = 0
71
72 def train(self):
73 self.is_training = True
74
75 def eval(self):
76 self.is_training = False
77
78 def set_dropout(self, dropout: float):
79 self.dropout = dropout
63 80
64 def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]): 81 def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]):
65 if algorithm == "leading": 82 if algorithm == "leading":
@@ -105,7 +122,13 @@ class MultiCLIPTokenizer(CLIPTokenizer):
105 return MultiCLIPTokenizerItem(new_tokens, ids) 122 return MultiCLIPTokenizerItem(new_tokens, ids)
106 123
107 def expand_id(self, id: int): 124 def expand_id(self, id: int):
108 return self.vector_shuffle(self.token_map[id]) if id in self.token_map else [id] 125 if id in self.token_map:
126 ids = self.token_map[id]
127 if self.is_training:
128 ids = dropout(self.vector_shuffle(ids), self.dropout)
129 return ids
130 else:
131 return [id]
109 132
110 def expand_ids(self, ids: list[int]): 133 def expand_ids(self, ids: list[int]):
111 return [ 134 return [