From e2d3a62bce63fcde940395a1c5618c4eb43385a9 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 14 Jan 2023 09:25:13 +0100 Subject: Cleanup --- models/clip/tokenizer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'models') diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 39c41ed..789b525 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py @@ -55,6 +55,9 @@ def shuffle_auto(tokens: list[int]): return shuffle_all(tokens) +ShuffleAlgorithm = Union[bool, Literal["all", "trailing", "leading", "between", "off"]] + + class MultiCLIPTokenizer(CLIPTokenizer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -73,7 +76,7 @@ class MultiCLIPTokenizer(CLIPTokenizer): def set_dropout(self, dropout: float): self.dropout = dropout - def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]): + def set_use_vector_shuffle(self, algorithm: ShuffleAlgorithm): if algorithm == "leading": self.vector_shuffle = shuffle_leading elif algorithm == "trailing": -- cgit v1.2.3-70-g09d2