From 56edf85c8b80d49c998bcf26392cce50d552137a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 31 Dec 2022 23:09:41 +0100 Subject: Update --- common.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) (limited to 'common.py') diff --git a/common.py b/common.py index 691be4e..0887197 100644 --- a/common.py +++ b/common.py @@ -24,13 +24,21 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC return [] filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()] - tokens = [filename.stem for filename in filenames] - for filename in embeddings_dir.iterdir(): - if filename.is_file(): - with safe_open(filename, framework="pt", device="cpu") as file: - embed = file.get_tensor("embed") - added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0]) - embeddings.add_embed(added.ids, embed) + new_tokens = [] + new_embeds = [] - return tokens + for filename in filenames: + with safe_open(filename, framework="pt", device="cpu") as file: + embed = file.get_tensor("embed") + + added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0]) + new_tokens.append(added) + new_embeds.append(embed) + + embeddings.resize(len(tokenizer)) + + for (new_token, embeds) in zip(new_tokens, new_embeds): + embeddings.add_embed(new_token.ids, embeds) + + return new_tokens -- cgit v1.2.3-54-g00ecf