summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-31 14:07:44 +0100
committerVolpeon <git@volpeon.ink>2022-12-31 14:07:44 +0100
commitdc463a6b8ef120b7a0643569b66f9109ed38c652 (patch)
treeae742a988b8541009a980c8b2f719696f9d7df27 /train_ti.py
parentFixes (diff)
downloadtextual-inversion-diff-dc463a6b8ef120b7a0643569b66f9109ed38c652.tar.gz
textual-inversion-diff-dc463a6b8ef120b7a0643569b66f9109ed38c652.tar.bz2
textual-inversion-diff-dc463a6b8ef120b7a0643569b66f9109ed38c652.zip
Simplified multi-vector embedding code
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py
index 3a5cfed..3776eb2 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -438,7 +438,7 @@ class Checkpointer(CheckpointerBase):
438 438
439 for new_token in self.new_tokens: 439 for new_token in self.new_tokens:
440 text_encoder.text_model.embeddings.save_embed( 440 text_encoder.text_model.embeddings.save_embed(
441 new_token.multi_ids, 441 new_token.ids,
442 checkpoints_path.joinpath(f"{slugify(new_token.token)}_{step}_{postfix}.bin") 442 checkpoints_path.joinpath(f"{slugify(new_token.token)}_{step}_{postfix}.bin")
443 ) 443 )
444 444
@@ -537,8 +537,7 @@ def main():
537 new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) 537 new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
538 538
539 for (new_token, init_ids) in zip(new_tokens, initializer_token_ids): 539 for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
540 embeddings.add_embed(new_token.placeholder_id) 540 embeddings.add_embed(new_token.ids, init_ids)
541 embeddings.add_embed(new_token.multi_ids, init_ids)
542 541
543 print(f"Added {len(new_tokens)} new tokens.") 542 print(f"Added {len(new_tokens)} new tokens.")
544 543