summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/clip/tokenizer.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 39c41ed..789b525 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -55,6 +55,9 @@ def shuffle_auto(tokens: list[int]):
55 return shuffle_all(tokens) 55 return shuffle_all(tokens)
56 56
57 57
58ShuffleAlgorithm = Union[bool, Literal["all", "trailing", "leading", "between", "off"]]
59
60
58class MultiCLIPTokenizer(CLIPTokenizer): 61class MultiCLIPTokenizer(CLIPTokenizer):
59 def __init__(self, *args, **kwargs): 62 def __init__(self, *args, **kwargs):
60 super().__init__(*args, **kwargs) 63 super().__init__(*args, **kwargs)
@@ -73,7 +76,7 @@ class MultiCLIPTokenizer(CLIPTokenizer):
73 def set_dropout(self, dropout: float): 76 def set_dropout(self, dropout: float):
74 self.dropout = dropout 77 self.dropout = dropout
75 78
76 def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]): 79 def set_use_vector_shuffle(self, algorithm: ShuffleAlgorithm):
77 if algorithm == "leading": 80 if algorithm == "leading":
78 self.vector_shuffle = shuffle_leading 81 self.vector_shuffle = shuffle_leading
79 elif algorithm == "trailing": 82 elif algorithm == "trailing":