diff options
author | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
commit | 8364ce697ddf6117fdd4f7222832d546d63880de (patch) | |
tree | 152c99815bbd8b2659d0dabe63c98f63151c97c2 /models/clip/tokenizer.py | |
parent | Fix LoRA training with DAdan (diff) | |
download | textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.gz textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.bz2 textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.zip |
Update
Diffstat (limited to 'models/clip/tokenizer.py')
-rw-r--r-- | models/clip/tokenizer.py | 23 |
1 files changed, 12 insertions, 11 deletions
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 789b525..a866641 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
@@ -91,18 +91,21 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
91 | self.vector_shuffle = shuffle_none | 91 | self.vector_shuffle = shuffle_none |
92 | 92 | ||
93 | def add_multi_tokens( | 93 | def add_multi_tokens( |
94 | self, | 94 | self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1 |
95 | new_tokens: Union[str, list[str]], | ||
96 | num_vectors: Union[int, list[int]] = 1 | ||
97 | ) -> Union[list[int], list[list[int]]]: | 95 | ) -> Union[list[int], list[list[int]]]: |
98 | if isinstance(new_tokens, list): | 96 | if isinstance(new_tokens, list): |
99 | if isinstance(num_vectors, int): | 97 | if isinstance(num_vectors, int): |
100 | num_vectors = [num_vectors] * len(new_tokens) | 98 | num_vectors = [num_vectors] * len(new_tokens) |
101 | 99 | ||
102 | if len(num_vectors) != len(new_tokens): | 100 | if len(num_vectors) != len(new_tokens): |
103 | raise ValueError("Expected new_tokens and num_vectors to have the same len") | 101 | raise ValueError( |
102 | "Expected new_tokens and num_vectors to have the same len" | ||
103 | ) | ||
104 | 104 | ||
105 | return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)] | 105 | return [ |
106 | self.add_multi_tokens(new_token, vecs) | ||
107 | for new_token, vecs in zip(new_tokens, num_vectors) | ||
108 | ] | ||
106 | 109 | ||
107 | if isinstance(num_vectors, list): | 110 | if isinstance(num_vectors, list): |
108 | raise ValueError("Expected num_vectors to be int for single token") | 111 | raise ValueError("Expected num_vectors to be int for single token") |
@@ -129,13 +132,11 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
129 | return [id] | 132 | return [id] |
130 | 133 | ||
131 | def expand_ids(self, ids: list[int]): | 134 | def expand_ids(self, ids: list[int]): |
132 | return [ | 135 | return [new_id for id in ids for new_id in self.expand_id(id)] |
133 | new_id | ||
134 | for id in ids | ||
135 | for new_id in self.expand_id(id) | ||
136 | ] | ||
137 | 136 | ||
138 | def expand_batched_ids(self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]): | 137 | def expand_batched_ids( |
138 | self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]] | ||
139 | ): | ||
139 | if isinstance(input_ids, (list, tuple)) and isinstance(input_ids[0], list): | 140 | if isinstance(input_ids, (list, tuple)) and isinstance(input_ids[0], list): |
140 | return [self.expand_ids(batch) for batch in input_ids] | 141 | return [self.expand_ids(batch) for batch in input_ids] |
141 | else: | 142 | else: |