diff options
Diffstat (limited to 'models/clip')
| -rw-r--r-- | models/clip/embeddings.py | 11 | ||||
| -rw-r--r-- | models/clip/tokenizer.py | 74 |
2 files changed, 70 insertions, 15 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index f90e7c2..8602142 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -120,3 +120,14 @@ def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbe | |||
| 120 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) | 120 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) |
| 121 | text_encoder.text_model.embeddings = text_embeddings | 121 | text_encoder.text_model.embeddings = text_embeddings |
| 122 | return text_embeddings | 122 | return text_embeddings |
| 123 | |||
| 124 | |||
| 125 | def unpatch_managed_embeddings(text_encoder: CLIPTextModel) -> CLIPTextEmbeddings: | ||
| 126 | text_encoder.text_model.embeddings.make_permanent() | ||
| 127 | |||
| 128 | text_embeddings = CLIPTextEmbeddings(text_encoder.config) | ||
| 129 | text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding | ||
| 130 | text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding | ||
| 131 | text_encoder.text_model.embeddings = text_embeddings | ||
| 132 | |||
| 133 | return text_embeddings | ||
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 37d69a9..ed9774e 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
| @@ -1,11 +1,54 @@ | |||
| 1 | import copy | 1 | import copy |
| 2 | from typing import NamedTuple, Union | 2 | from typing import NamedTuple, Union, Literal |
| 3 | 3 | ||
| 4 | import numpy as np | 4 | import numpy as np |
| 5 | 5 | ||
| 6 | from transformers import CLIPTokenizer | 6 | from transformers import CLIPTokenizer |
| 7 | 7 | ||
| 8 | 8 | ||
| 9 | def shuffle_all(tokens: list[int]): | ||
| 10 | if len(tokens) >= 2: | ||
| 11 | tokens = copy.copy(tokens) | ||
| 12 | np.random.shuffle(tokens) | ||
| 13 | return tokens | ||
| 14 | |||
| 15 | |||
| 16 | def shuffle_leading(tokens: list[int]): | ||
| 17 | if len(tokens) >= 3: | ||
| 18 | subtokens = tokens[:-1] | ||
| 19 | np.random.shuffle(subtokens) | ||
| 20 | tokens = subtokens + tokens[-1:] | ||
| 21 | return tokens | ||
| 22 | |||
| 23 | |||
| 24 | def shuffle_trailing(tokens: list[int]): | ||
| 25 | if len(tokens) >= 3: | ||
| 26 | subtokens = tokens[1:] | ||
| 27 | np.random.shuffle(subtokens) | ||
| 28 | tokens = tokens[:1] + subtokens | ||
| 29 | return tokens | ||
| 30 | |||
| 31 | |||
| 32 | def shuffle_between(tokens: list[int]): | ||
| 33 | if len(tokens) >= 4: | ||
| 34 | subtokens = tokens[1:-1] | ||
| 35 | np.random.shuffle(subtokens) | ||
| 36 | tokens = tokens[:1] + subtokens + tokens[-1:] | ||
| 37 | return tokens | ||
| 38 | |||
| 39 | |||
| 40 | def shuffle_none(tokens: list[int]): | ||
| 41 | return tokens | ||
| 42 | |||
| 43 | |||
| 44 | def shuffle_auto(tokens: list[int]): | ||
| 45 | if len(tokens) >= 4: | ||
| 46 | return shuffle_between(tokens) | ||
| 47 | if len(tokens) >= 3: | ||
| 48 | return shuffle_trailing(tokens) | ||
| 49 | return shuffle_all(tokens) | ||
| 50 | |||
| 51 | |||
| 9 | class MultiCLIPTokenizerItem(NamedTuple): | 52 | class MultiCLIPTokenizerItem(NamedTuple): |
| 10 | token: str | 53 | token: str |
| 11 | ids: list[int] | 54 | ids: list[int] |
| @@ -15,10 +58,21 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
| 15 | def __init__(self, *args, **kwargs): | 58 | def __init__(self, *args, **kwargs): |
| 16 | super().__init__(*args, **kwargs) | 59 | super().__init__(*args, **kwargs) |
| 17 | self.token_map: dict[int, list[int]] = {} | 60 | self.token_map: dict[int, list[int]] = {} |
| 18 | self.vector_shuffle = False | 61 | self.vector_shuffle = shuffle_none |
| 19 | 62 | ||
| 20 | def set_use_vector_shuffle(self, enable: bool): | 63 | def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]): |
| 21 | self.vector_shuffle = enable | 64 | if algorithm == "leading": |
| 65 | self.vector_shuffle = shuffle_leading | ||
| 66 | elif algorithm == "trailing": | ||
| 67 | self.vector_shuffle = shuffle_trailing | ||
| 68 | elif algorithm == "between": | ||
| 69 | self.vector_shuffle = shuffle_between | ||
| 70 | elif algorithm == "auto": | ||
| 71 | self.vector_shuffle = shuffle_auto | ||
| 72 | elif algorithm == True or algorithm == "all": | ||
| 73 | self.vector_shuffle = shuffle_all | ||
| 74 | else: | ||
| 75 | self.vector_shuffle = shuffle_none | ||
| 22 | 76 | ||
| 23 | def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem: | 77 | def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem: |
| 24 | if isinstance(new_tokens, list): | 78 | if isinstance(new_tokens, list): |
| @@ -47,17 +101,7 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
| 47 | return MultiCLIPTokenizerItem(new_tokens, ids) | 101 | return MultiCLIPTokenizerItem(new_tokens, ids) |
| 48 | 102 | ||
| 49 | def expand_id(self, id: int): | 103 | def expand_id(self, id: int): |
| 50 | if id in self.token_map: | 104 | return self.vector_shuffle(self.token_map[id]) if id in self.token_map else [id] |
| 51 | tokens = self.token_map[id] | ||
| 52 | |||
| 53 | if self.vector_shuffle and len(tokens) > 2: | ||
| 54 | subtokens = tokens[1:-1] | ||
| 55 | np.random.shuffle(subtokens) | ||
| 56 | tokens = tokens[:1] + subtokens + tokens[-1:] | ||
| 57 | |||
| 58 | return tokens | ||
| 59 | else: | ||
| 60 | return [id] | ||
| 61 | 105 | ||
| 62 | def expand_ids(self, ids: list[int]): | 106 | def expand_ids(self, ids: list[int]): |
| 63 | return [ | 107 | return [ |
