diff options
Diffstat (limited to 'models/clip')
| -rw-r--r-- | models/clip/tokenizer.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 39c41ed..789b525 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
| @@ -55,6 +55,9 @@ def shuffle_auto(tokens: list[int]): | |||
| 55 | return shuffle_all(tokens) | 55 | return shuffle_all(tokens) |
| 56 | 56 | ||
| 57 | 57 | ||
| 58 | ShuffleAlgorithm = Union[bool, Literal["all", "trailing", "leading", "between", "off"]] | ||
| 59 | |||
| 60 | |||
| 58 | class MultiCLIPTokenizer(CLIPTokenizer): | 61 | class MultiCLIPTokenizer(CLIPTokenizer): |
| 59 | def __init__(self, *args, **kwargs): | 62 | def __init__(self, *args, **kwargs): |
| 60 | super().__init__(*args, **kwargs) | 63 | super().__init__(*args, **kwargs) |
| @@ -73,7 +76,7 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
| 73 | def set_dropout(self, dropout: float): | 76 | def set_dropout(self, dropout: float): |
| 74 | self.dropout = dropout | 77 | self.dropout = dropout |
| 75 | 78 | ||
| 76 | def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]): | 79 | def set_use_vector_shuffle(self, algorithm: ShuffleAlgorithm): |
| 77 | if algorithm == "leading": | 80 | if algorithm == "leading": |
| 78 | self.vector_shuffle = shuffle_leading | 81 | self.vector_shuffle = shuffle_leading |
| 79 | elif algorithm == "trailing": | 82 | elif algorithm == "trailing": |
