summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-07 07:11:51 +0100
committerVolpeon <git@volpeon.ink>2023-03-07 07:11:51 +0100
commitfe3113451fdde72ddccfc71639f0a2a1e146209a (patch)
treeba4114faf1bd00a642f97b5e7729ad74213c3b80 /training/functional.py
parentUpdate (diff)
downloadtextual-inversion-diff-fe3113451fdde72ddccfc71639f0a2a1e146209a.tar.gz
textual-inversion-diff-fe3113451fdde72ddccfc71639f0a2a1e146209a.tar.bz2
textual-inversion-diff-fe3113451fdde72ddccfc71639f0a2a1e146209a.zip
Update
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py6
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/training/functional.py b/training/functional.py
index 27a43c2..4565612 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -231,12 +231,16 @@ def add_placeholder_tokens(
231 embeddings: ManagedCLIPTextEmbeddings, 231 embeddings: ManagedCLIPTextEmbeddings,
232 placeholder_tokens: list[str], 232 placeholder_tokens: list[str],
233 initializer_tokens: list[str], 233 initializer_tokens: list[str],
234 num_vectors: Union[list[int], int] 234 num_vectors: Optional[Union[list[int], int]] = None,
235): 235):
236 initializer_token_ids = [ 236 initializer_token_ids = [
237 tokenizer.encode(token, add_special_tokens=False) 237 tokenizer.encode(token, add_special_tokens=False)
238 for token in initializer_tokens 238 for token in initializer_tokens
239 ] 239 ]
240
241 if num_vectors is None:
242 num_vectors = [len(ids) for ids in initializer_token_ids]
243
240 placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors) 244 placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
241 245
242 embeddings.resize(len(tokenizer)) 246 embeddings.resize(len(tokenizer))