summary | refs | log | tree | commit | diff | stats
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r-- util/files.py 16
1 files changed, 11 insertions, 5 deletions
diff --git a/util/files.py b/util/files.py
index 2712525..73ff802 100644
--- a/util/files.py
+++ b/util/files.py
@@ -8,7 +8,7 @@ from safetensors import safe_open
8 8
9 9
10def load_config(filename): 10def load_config(filename):
11 with open(filename, 'rt') as f: 11 with open(filename, "rt") as f:
12 config = json.load(f) 12 config = json.load(f)
13 13
14 args = config["args"] 14 args = config["args"]
@@ -19,11 +19,17 @@ def load_config(filename):
19 return args 19 return args
20 20
21 21
22def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path): 22def load_embeddings_from_dir(
23 tokenizer: MultiCLIPTokenizer,
24 embeddings: ManagedCLIPTextEmbeddings,
25 embeddings_dir: Path,
26):
23 if not embeddings_dir.exists() or not embeddings_dir.is_dir(): 27 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
24 return [] 28 return [], []
25 29
26 filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()] 30 filenames = [
31 filename for filename in embeddings_dir.iterdir() if filename.is_file()
32 ]
27 tokens = [filename.stem for filename in filenames] 33 tokens = [filename.stem for filename in filenames]
28 34
29 new_ids: list[list[int]] = [] 35 new_ids: list[list[int]] = []
@@ -39,7 +45,7 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC
39 45
40 embeddings.resize(len(tokenizer)) 46 embeddings.resize(len(tokenizer))
41 47
42 for (new_id, embeds) in zip(new_ids, new_embeds): 48 for new_id, embeds in zip(new_ids, new_embeds):
43 embeddings.add_embed(new_id, embeds) 49 embeddings.add_embed(new_id, embeds)
44 50
45 return tokens, new_ids 51 return tokens, new_ids