diff options
author | Volpeon <git@volpeon.ink> | 2023-01-05 10:19:38 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-05 10:19:38 +0100 |
commit | 6c64f769043c8212b1a5778e857af691a828798d (patch) | |
tree | fe4cdf2a4e28e86e31bb7ccd8885c0a42c8632dc /common.py | |
parent | Update (diff) | |
download | textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.tar.gz textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.tar.bz2 textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.zip |
Various cleanups
Diffstat (limited to 'common.py')
-rw-r--r-- | common.py | 44 |
1 files changed, 0 insertions, 44 deletions
diff --git a/common.py b/common.py deleted file mode 100644 index 0887197..0000000 --- a/common.py +++ /dev/null | |||
@@ -1,44 +0,0 @@ | |||
1 | from pathlib import Path | ||
2 | import json | ||
3 | |||
4 | from models.clip.embeddings import ManagedCLIPTextEmbeddings | ||
5 | from models.clip.tokenizer import MultiCLIPTokenizer | ||
6 | |||
7 | from safetensors import safe_open | ||
8 | |||
9 | |||
def load_config(filename):
    """Load a JSON config file, recursively merging in any "base" config.

    The file must contain an "args" object. When a "base" key is present,
    it is resolved relative to *filename*'s directory and loaded first;
    keys in the current file's "args" override those inherited from the
    base (dict union: base | current).

    Args:
        filename: path (str or Path) to the JSON config file.

    Returns:
        The merged "args" dict.
    """
    path = Path(filename)

    # Explicit encoding: JSON is UTF-8 by spec; the platform default
    # locale encoding could mis-decode it on some systems.
    with open(path, "rt", encoding="utf-8") as f:
        config = json.load(f)

    args = config["args"]

    if "base" in config:
        # Base values go on the left so the current file's entries win.
        args = load_config(path.parent / config["base"]) | args

    return args
20 | |||
21 | |||
def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path):
    """Register every embedding file in *embeddings_dir* with the tokenizer.

    Each regular file in the directory is opened with safetensors and its
    "embed" tensor is added under new tokens named after the file stem.
    The embedding table is resized once, after all tokens are registered,
    and then each tensor is written into its token's slots.

    Returns the list of newly added token entries; an empty list when the
    directory is missing or is not a directory.
    """
    if not (embeddings_dir.exists() and embeddings_dir.is_dir()):
        return []

    new_tokens = []
    new_embeds = []

    for path in (entry for entry in embeddings_dir.iterdir() if entry.is_file()):
        with safe_open(path, framework="pt", device="cpu") as file:
            tensor = file.get_tensor("embed")

        new_tokens.append(tokenizer.add_multi_tokens(path.stem, tensor.shape[0]))
        new_embeds.append(tensor)

    # Grow the embedding table once for all newly registered tokens.
    embeddings.resize(len(tokenizer))

    for token, tensor in zip(new_tokens, new_embeds):
        embeddings.add_embed(token.ids, tensor)

    return new_tokens