summaryrefslogtreecommitdiffstats
path: root/common.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-31 23:09:41 +0100
committerVolpeon <git@volpeon.ink>2022-12-31 23:09:41 +0100
commit56edf85c8b80d49c998bcf26392cce50d552137a (patch)
treec4fa91f1dc951329a6d276731308d657eec644c8 /common.py
parentBugfixes for multi-vector token handling (diff)
downloadtextual-inversion-diff-56edf85c8b80d49c998bcf26392cce50d552137a.tar.gz
textual-inversion-diff-56edf85c8b80d49c998bcf26392cce50d552137a.tar.bz2
textual-inversion-diff-56edf85c8b80d49c998bcf26392cce50d552137a.zip
Update
Diffstat (limited to 'common.py')
-rw-r--r--common.py24
1 files changed, 16 insertions, 8 deletions
diff --git a/common.py b/common.py
index 691be4e..0887197 100644
--- a/common.py
+++ b/common.py
@@ -24,13 +24,21 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC
24 return [] 24 return []
25 25
26 filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()] 26 filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()]
27 tokens = [filename.stem for filename in filenames]
28 27
29 for filename in embeddings_dir.iterdir(): 28 new_tokens = []
30 if filename.is_file(): 29 new_embeds = []
31 with safe_open(filename, framework="pt", device="cpu") as file:
32 embed = file.get_tensor("embed")
33 added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
34 embeddings.add_embed(added.ids, embed)
35 30
36 return tokens 31 for filename in filenames:
32 with safe_open(filename, framework="pt", device="cpu") as file:
33 embed = file.get_tensor("embed")
34
35 added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
36 new_tokens.append(added)
37 new_embeds.append(embed)
38
39 embeddings.resize(len(tokenizer))
40
41 for (new_token, embeds) in zip(new_tokens, new_embeds):
42 embeddings.add_embed(new_token.ids, embeds)
43
44 return new_tokens