diff options
author | Volpeon <git@volpeon.ink> | 2022-12-31 23:09:41 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-31 23:09:41 +0100 |
commit | 56edf85c8b80d49c998bcf26392cce50d552137a (patch) | |
tree | c4fa91f1dc951329a6d276731308d657eec644c8 /common.py | |
parent | Bugfixes for multi-vector token handling (diff) | |
download | textual-inversion-diff-56edf85c8b80d49c998bcf26392cce50d552137a.tar.gz textual-inversion-diff-56edf85c8b80d49c998bcf26392cce50d552137a.tar.bz2 textual-inversion-diff-56edf85c8b80d49c998bcf26392cce50d552137a.zip |
Update
Diffstat (limited to 'common.py')
-rw-r--r-- | common.py | 24 |
1 files changed, 16 insertions, 8 deletions
@@ -24,13 +24,21 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC | |||
24 | return [] | 24 | return [] |
25 | 25 | ||
26 | filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()] | 26 | filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()] |
27 | tokens = [filename.stem for filename in filenames] | ||
28 | 27 | ||
29 | for filename in embeddings_dir.iterdir(): | 28 | new_tokens = [] |
30 | if filename.is_file(): | 29 | new_embeds = [] |
31 | with safe_open(filename, framework="pt", device="cpu") as file: | ||
32 | embed = file.get_tensor("embed") | ||
33 | added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0]) | ||
34 | embeddings.add_embed(added.ids, embed) | ||
35 | 30 | ||
36 | return tokens | 31 | for filename in filenames: |
32 | with safe_open(filename, framework="pt", device="cpu") as file: | ||
33 | embed = file.get_tensor("embed") | ||
34 | |||
35 | added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0]) | ||
36 | new_tokens.append(added) | ||
37 | new_embeds.append(embed) | ||
38 | |||
39 | embeddings.resize(len(tokenizer)) | ||
40 | |||
41 | for (new_token, embeds) in zip(new_tokens, new_embeds): | ||
42 | embeddings.add_embed(new_token.ids, embeds) | ||
43 | |||
44 | return new_tokens | ||