diff options
Diffstat (limited to 'models/clip')
-rw-r--r-- | models/clip/tokenizer.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 39c41ed..789b525 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
@@ -55,6 +55,9 @@ def shuffle_auto(tokens: list[int]): | |||
55 | return shuffle_all(tokens) | 55 | return shuffle_all(tokens) |
56 | 56 | ||
57 | 57 | ||
58 | ShuffleAlgorithm = Union[bool, Literal["all", "trailing", "leading", "between", "off"]] | ||
59 | |||
60 | |||
58 | class MultiCLIPTokenizer(CLIPTokenizer): | 61 | class MultiCLIPTokenizer(CLIPTokenizer): |
59 | def __init__(self, *args, **kwargs): | 62 | def __init__(self, *args, **kwargs): |
60 | super().__init__(*args, **kwargs) | 63 | super().__init__(*args, **kwargs) |
@@ -73,7 +76,7 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
73 | def set_dropout(self, dropout: float): | 76 | def set_dropout(self, dropout: float): |
74 | self.dropout = dropout | 77 | self.dropout = dropout |
75 | 78 | ||
76 | def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]): | 79 | def set_use_vector_shuffle(self, algorithm: ShuffleAlgorithm): |
77 | if algorithm == "leading": | 80 | if algorithm == "leading": |
78 | self.vector_shuffle = shuffle_leading | 81 | self.vector_shuffle = shuffle_leading |
79 | elif algorithm == "trailing": | 82 | elif algorithm == "trailing": |