Diffstat (limited to 'models/clip')
 -rw-r--r--  models/clip/embeddings.py | 11
 -rw-r--r--  models/clip/tokenizer.py  | 76
 2 files changed, 71 insertions(+), 16 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index f90e7c2..8602142 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -120,3 +120,14 @@ def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbe
     text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
+
+
+def unpatch_managed_embeddings(text_encoder: CLIPTextModel) -> CLIPTextEmbeddings:
+    text_encoder.text_model.embeddings.make_permanent()
+
+    text_embeddings = CLIPTextEmbeddings(text_encoder.config)
+    text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding
+    text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding
+    text_encoder.text_model.embeddings = text_embeddings
+
+    return text_embeddings
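A brief usage sketch (not part of this commit) of how the new unpatch_managed_embeddings helper pairs with the existing patch_managed_embeddings; the checkpoint name and import path are assumptions based on the repository layout, and the training step in between is elided.

# Hypothetical usage sketch; concrete names below are assumptions.
from transformers import CLIPTextModel

from models.clip.embeddings import patch_managed_embeddings, unpatch_managed_embeddings

# Assumed checkpoint; any CLIPTextModel works the same way.
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

managed = patch_managed_embeddings(text_encoder)  # swaps in ManagedCLIPTextEmbeddings
# ... add placeholder embeddings / train via `managed` here ...
plain = unpatch_managed_embeddings(text_encoder)  # calls make_permanent(), then restores plain CLIPTextEmbeddings
assert text_encoder.text_model.embeddings is plain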
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 37d69a9..ed9774e 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -1,11 +1,54 @@
 import copy
-from typing import NamedTuple, Union
+from typing import NamedTuple, Union, Literal
 
 import numpy as np
 
 from transformers import CLIPTokenizer
 
 
+def shuffle_all(tokens: list[int]):
+    if len(tokens) >= 2:
+        tokens = copy.copy(tokens)
+        np.random.shuffle(tokens)
+    return tokens
+
+
+def shuffle_leading(tokens: list[int]):
+    if len(tokens) >= 3:
+        subtokens = tokens[:-1]
+        np.random.shuffle(subtokens)
+        tokens = subtokens + tokens[-1:]
+    return tokens
+
+
+def shuffle_trailing(tokens: list[int]):
+    if len(tokens) >= 3:
+        subtokens = tokens[1:]
+        np.random.shuffle(subtokens)
+        tokens = tokens[:1] + subtokens
+    return tokens
+
+
+def shuffle_between(tokens: list[int]):
+    if len(tokens) >= 4:
+        subtokens = tokens[1:-1]
+        np.random.shuffle(subtokens)
+        tokens = tokens[:1] + subtokens + tokens[-1:]
+    return tokens
+
+
+def shuffle_none(tokens: list[int]):
+    return tokens
+
+
+def shuffle_auto(tokens: list[int]):
+    if len(tokens) >= 4:
+        return shuffle_between(tokens)
+    if len(tokens) >= 3:
+        return shuffle_trailing(tokens)
+    return shuffle_all(tokens)
+
+
 class MultiCLIPTokenizerItem(NamedTuple):
     token: str
     ids: list[int]
@@ -15,10 +58,21 @@ class MultiCLIPTokenizer(CLIPTokenizer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.token_map: dict[int, list[int]] = {}
-        self.vector_shuffle = False
+        self.vector_shuffle = shuffle_none
 
-    def set_use_vector_shuffle(self, enable: bool):
-        self.vector_shuffle = enable
+    def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]):
+        if algorithm == "leading":
+            self.vector_shuffle = shuffle_leading
+        elif algorithm == "trailing":
+            self.vector_shuffle = shuffle_trailing
+        elif algorithm == "between":
+            self.vector_shuffle = shuffle_between
+        elif algorithm == "auto":
+            self.vector_shuffle = shuffle_auto
+        elif algorithm == True or algorithm == "all":
+            self.vector_shuffle = shuffle_all
+        else:
+            self.vector_shuffle = shuffle_none
 
     def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem:
         if isinstance(new_tokens, list):
@@ -47,17 +101,7 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         return MultiCLIPTokenizerItem(new_tokens, ids)
 
     def expand_id(self, id: int):
-        if id in self.token_map:
-            tokens = self.token_map[id]
-
-            if self.vector_shuffle and len(tokens) > 2:
-                subtokens = tokens[1:-1]
-                np.random.shuffle(subtokens)
-                tokens = tokens[:1] + subtokens + tokens[-1:]
-
-            return tokens
-        else:
-            return [id]
+        return self.vector_shuffle(self.token_map[id]) if id in self.token_map else [id]
 
     def expand_ids(self, ids: list[int]):
         return [
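For context, a short sketch (not part of this commit) of how the reworked shuffle selection might be exercised; the checkpoint name, placeholder token, and vector count below are illustrative assumptions.

# Hypothetical usage sketch; concrete names are assumptions.
from models.clip.tokenizer import (
    MultiCLIPTokenizer,
    shuffle_all,
    shuffle_auto,
    shuffle_between,
)

# Assumed checkpoint; MultiCLIPTokenizer inherits from_pretrained from CLIPTokenizer.
tokenizer = MultiCLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# Register a placeholder token backed by several embedding vectors (illustrative values).
item = tokenizer.add_multi_tokens("<my-token>", num_vectors=4)
print(item.token, item.ids)

# Select the shuffle strategy by name; "auto" keeps the first and last vector in
# place when there are at least 4 ids and falls back to weaker shuffles otherwise.
tokenizer.set_use_vector_shuffle("auto")

# The module-level helpers also work directly on plain id lists:
ids = [10, 11, 12, 13, 14]
print(shuffle_all(ids))      # any position may move
print(shuffle_between(ids))  # first and last ids stay fixed
print(shuffle_auto(ids))     # strategy chosen from the list length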