From dc463a6b8ef120b7a0643569b66f9109ed38c652 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 31 Dec 2022 14:07:44 +0100
Subject: Simplified multi-vector embedding code

---
 train_ti.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 3a5cfed..3776eb2 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -438,7 +438,7 @@ class Checkpointer(CheckpointerBase):
 
         for new_token in self.new_tokens:
             text_encoder.text_model.embeddings.save_embed(
-                new_token.multi_ids,
+                new_token.ids,
                 checkpoints_path.joinpath(f"{slugify(new_token.token)}_{step}_{postfix}.bin")
             )
 
@@ -537,8 +537,7 @@ def main():
     new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
 
     for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
-        embeddings.add_embed(new_token.placeholder_id)
-        embeddings.add_embed(new_token.multi_ids, init_ids)
+        embeddings.add_embed(new_token.ids, init_ids)
 
     print(f"Added {len(new_tokens)} new tokens.")
 
-- 
cgit v1.2.3-54-g00ecf