diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-31 23:09:41 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-31 23:09:41 +0100 |
| commit | 56edf85c8b80d49c998bcf26392cce50d552137a (patch) | |
| tree | c4fa91f1dc951329a6d276731308d657eec644c8 /common.py | |
| parent | Bugfixes for multi-vector token handling (diff) | |
| download | textual-inversion-diff-56edf85c8b80d49c998bcf26392cce50d552137a.tar.gz textual-inversion-diff-56edf85c8b80d49c998bcf26392cce50d552137a.tar.bz2 textual-inversion-diff-56edf85c8b80d49c998bcf26392cce50d552137a.zip | |
Update
Diffstat (limited to 'common.py')
| -rw-r--r-- | common.py | 24 |
1 file changed, 16 insertions, 8 deletions
| @@ -24,13 +24,21 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC | |||
| 24 | return [] | 24 | return [] |
| 25 | 25 | ||
| 26 | filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()] | 26 | filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()] |
| 27 | tokens = [filename.stem for filename in filenames] | ||
| 28 | 27 | ||
| 29 | for filename in embeddings_dir.iterdir(): | 28 | new_tokens = [] |
| 30 | if filename.is_file(): | 29 | new_embeds = [] |
| 31 | with safe_open(filename, framework="pt", device="cpu") as file: | ||
| 32 | embed = file.get_tensor("embed") | ||
| 33 | added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0]) | ||
| 34 | embeddings.add_embed(added.ids, embed) | ||
| 35 | 30 | ||
| 36 | return tokens | 31 | for filename in filenames: |
| 32 | with safe_open(filename, framework="pt", device="cpu") as file: | ||
| 33 | embed = file.get_tensor("embed") | ||
| 34 | |||
| 35 | added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0]) | ||
| 36 | new_tokens.append(added) | ||
| 37 | new_embeds.append(embed) | ||
| 38 | |||
| 39 | embeddings.resize(len(tokenizer)) | ||
| 40 | |||
| 41 | for (new_token, embeds) in zip(new_tokens, new_embeds): | ||
| 42 | embeddings.add_embed(new_token.ids, embeds) | ||
| 43 | |||
| 44 | return new_tokens | ||
