summaryrefslogtreecommitdiffstats
path: root/util.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-05 10:19:38 +0100
committerVolpeon <git@volpeon.ink>2023-01-05 10:19:38 +0100
commit6c64f769043c8212b1a5778e857af691a828798d (patch)
treefe4cdf2a4e28e86e31bb7ccd8885c0a42c8632dc /util.py
parentUpdate (diff)
downloadtextual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.tar.gz
textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.tar.bz2
textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.zip
Various cleanups
Diffstat (limited to 'util.py')
-rw-r--r--util.py45
1 files changed, 45 insertions, 0 deletions
diff --git a/util.py b/util.py
new file mode 100644
index 0000000..545bcb5
--- /dev/null
+++ b/util.py
@@ -0,0 +1,45 @@
1from pathlib import Path
2import json
3
4from models.clip.embeddings import ManagedCLIPTextEmbeddings
5from models.clip.tokenizer import MultiCLIPTokenizer
6
7from safetensors import safe_open
8
9
def load_config(filename):
    """Load the ``"args"`` mapping from a JSON config file, applying bases.

    The file must contain a top-level ``"args"`` entry. If it also has a
    ``"base"`` entry, that value is treated as a path relative to the
    directory of the current file; the base config is loaded the same way
    and its args are merged underneath, so values in the more specific file
    override values inherited from its base.

    Args:
        filename: Path (``str`` or ``Path``) of the JSON config file.

    Returns:
        The merged args. When no ``"base"`` is given this is exactly the
        file's ``"args"`` value.

    Raises:
        KeyError: A config file in the chain has no ``"args"`` entry.
        ValueError: The ``"base"`` references form a cycle.
        OSError: A config file in the chain cannot be opened.
        json.JSONDecodeError: A config file is not valid JSON.
    """
    layers = []
    visited = set()
    current = Path(filename)

    while True:
        # Resolve so the same file reached via different relative spellings
        # ("../cfg/a.json" vs "a.json") is still recognised as a repeat.
        resolved = current.resolve()
        if resolved in visited:
            raise ValueError(f"Circular \"base\" reference in config: {current}")
        visited.add(resolved)

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(current, "rt", encoding="utf-8") as f:
            config = json.load(f)

        layers.append(config["args"])

        if "base" not in config:
            break

        current = current.parent.joinpath(config["base"])

    # The last layer is the root-most base. Fold the more specific layers on
    # top of it so that later (more specific) values win.
    args = layers.pop()
    while layers:
        args = args | layers.pop()

    return args
20
21
def load_embeddings_from_dir(
    tokenizer: MultiCLIPTokenizer,
    embeddings: ManagedCLIPTextEmbeddings,
    embeddings_dir: Path,
) -> tuple[list[str], list[list[int]]]:
    """Register every embedding file found directly inside *embeddings_dir*.

    Each regular file in the directory is opened with safetensors and its
    ``"embed"`` tensor is read on the CPU. The file's stem is used as the
    token name and passed to ``tokenizer.add_multi_tokens`` together with
    ``embed.shape[0]``. Once all tokens have been added, ``embeddings`` is
    resized to ``len(tokenizer)`` and each tensor is stored with
    ``embeddings.add_embed``.

    Files are processed in sorted path order so the ids handed out by the
    tokenizer do not depend on the filesystem's directory listing order.

    Args:
        tokenizer: Tokenizer that receives one multi-token per file.
        embeddings: Embedding container that is resized and filled.
        embeddings_dir: Directory to scan (not recursive).

    Returns:
        A ``(tokens, new_ids)`` pair: the file stems, and for each stem the
        value returned by ``tokenizer.add_multi_tokens``. Both lists are
        empty when *embeddings_dir* is missing or is not a directory, so the
        result can always be unpacked into two names.
    """
    # is_dir() is already False for paths that do not exist.
    if not embeddings_dir.is_dir():
        return [], []

    filenames = sorted(path for path in embeddings_dir.iterdir() if path.is_file())
    tokens = [path.stem for path in filenames]

    new_ids: list[list[int]] = []
    new_embeds = []

    for path in filenames:
        with safe_open(path, framework="pt", device="cpu") as file:
            embed = file.get_tensor("embed")

        # NOTE(review): embed.shape[0] is presumably the number of vectors
        # that make up this token -- confirm against add_multi_tokens.
        new_ids.append(tokenizer.add_multi_tokens(path.stem, embed.shape[0]))
        new_embeds.append(embed)

    # Resize once, after every token is known, then fill in the new rows.
    embeddings.resize(len(tokenizer))

    for ids, embed in zip(new_ids, new_embeds):
        embeddings.add_embed(ids, embed)

    return tokens, new_ids