diff options
author | Volpeon <git@volpeon.ink> | 2023-03-07 07:11:51 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-07 07:11:51 +0100 |
commit | fe3113451fdde72ddccfc71639f0a2a1e146209a (patch) | |
tree | ba4114faf1bd00a642f97b5e7729ad74213c3b80 /training/functional.py | |
parent | Update (diff) | |
download | textual-inversion-diff-fe3113451fdde72ddccfc71639f0a2a1e146209a.tar.gz textual-inversion-diff-fe3113451fdde72ddccfc71639f0a2a1e146209a.tar.bz2 textual-inversion-diff-fe3113451fdde72ddccfc71639f0a2a1e146209a.zip |
Update
Diffstat (limited to 'training/functional.py')
-rw-r--r-- | training/functional.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/training/functional.py b/training/functional.py index 27a43c2..4565612 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -231,12 +231,16 @@ def add_placeholder_tokens( | |||
231 | embeddings: ManagedCLIPTextEmbeddings, | 231 | embeddings: ManagedCLIPTextEmbeddings, |
232 | placeholder_tokens: list[str], | 232 | placeholder_tokens: list[str], |
233 | initializer_tokens: list[str], | 233 | initializer_tokens: list[str], |
234 | num_vectors: Union[list[int], int] | 234 | num_vectors: Optional[Union[list[int], int]] = None, |
235 | ): | 235 | ): |
236 | initializer_token_ids = [ | 236 | initializer_token_ids = [ |
237 | tokenizer.encode(token, add_special_tokens=False) | 237 | tokenizer.encode(token, add_special_tokens=False) |
238 | for token in initializer_tokens | 238 | for token in initializer_tokens |
239 | ] | 239 | ] |
240 | |||
241 | if num_vectors is None: | ||
242 | num_vectors = [len(ids) for ids in initializer_token_ids] | ||
243 | |||
240 | placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors) | 244 | placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors) |
241 | 245 | ||
242 | embeddings.resize(len(tokenizer)) | 246 | embeddings.resize(len(tokenizer)) |