From a72b6260c117cabe4fcb2996cce4f870986df99b Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Tue, 3 Jan 2023 12:40:16 +0100
Subject: Added vector dropout

---
 models/clip/tokenizer.py | 27 +++++++++++++++++++++++++--
 1 file changed, 25 insertions(+), 2 deletions(-)

(limited to 'models/clip')

diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index bd0bd21..11a3df0 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -6,6 +6,12 @@ import numpy as np
 from transformers import CLIPTokenizer
 
 
+def dropout(tokens: list[int], dropout: float):  # randomly thin out `tokens`; `dropout` is the per-token drop threshold
+    if dropout != 0:  # a dropout of 0 returns the input list object unchanged, with no random draws
+        tokens = [token for token in tokens if np.random.random() > dropout]  # keep a token only when its draw exceeds `dropout`; the result may be empty
+    return tokens
+
+
 def shuffle_all(tokens: list[int]):
     if len(tokens) >= 2:
         tokens = copy.copy(tokens)
@@ -59,7 +65,18 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         super().__init__(*args, **kwargs)
 
         self.token_map: dict[int, list[int]] = {}
-        self.vector_shuffle = shuffle_none
+        self.is_training = False
+        self.vector_shuffle = shuffle_auto
+        self.dropout = 0
+
+    def train(self):  # switch the tokenizer to training mode
+        self.is_training = True  # expand_id applies vector shuffle and dropout only while this flag is set
+
+    def eval(self):  # switch the tokenizer to evaluation mode
+        self.is_training = False  # expand_id then returns mapped ids as stored, without shuffle or dropout
+
+    def set_dropout(self, dropout: float):  # set the threshold handed to dropout() by expand_id in training mode
+        self.dropout = dropout  # 0 (the value assigned in __init__) disables dropout
 
     def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]):
         if algorithm == "leading":
@@ -105,7 +122,13 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         return MultiCLIPTokenizerItem(new_tokens, ids)
 
     def expand_id(self, id: int):
-        return self.vector_shuffle(self.token_map[id]) if id in self.token_map else [id]
+        if id in self.token_map:  # id has an entry in token_map, so it expands to the mapped id list
+            ids = self.token_map[id]
+            if self.is_training:  # shuffle and dropout are applied only in training mode
+                ids = dropout(self.vector_shuffle(ids), self.dropout)  # shuffle first, then randomly drop entries
+            return ids  # outside training this is the token_map list object itself, not a copy
+        else:
+            return [id]  # ids without a token_map entry pass through as a single-element list
 
     def expand_ids(self, ids: list[int]):
         return [
-- 
cgit v1.2.3-70-g09d2