diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 1 |
1 files changed, 1 insertions, 0 deletions
diff --git a/train_ti.py b/train_ti.py index 3776eb2..19348e5 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -535,6 +535,7 @@ def main(): | |||
535 | ] | 535 | ] |
536 | 536 | ||
537 | new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) | 537 | new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) |
538 | embeddings.resize(len(tokenizer)) | ||
538 | 539 | ||
539 | for (new_token, init_ids) in zip(new_tokens, initializer_token_ids): | 540 | for (new_token, init_ids) in zip(new_tokens, initializer_token_ids): |
540 | embeddings.add_embed(new_token.ids, init_ids) | 541 | embeddings.add_embed(new_token.ids, init_ids) |