diff options
Diffstat (limited to 'util')
-rw-r--r-- | util/files.py | 16 |
1 files changed, 11 insertions, 5 deletions
diff --git a/util/files.py b/util/files.py index 2712525..73ff802 100644 --- a/util/files.py +++ b/util/files.py | |||
@@ -8,7 +8,7 @@ from safetensors import safe_open | |||
8 | 8 | ||
9 | 9 | ||
10 | def load_config(filename): | 10 | def load_config(filename): |
11 | with open(filename, 'rt') as f: | 11 | with open(filename, "rt") as f: |
12 | config = json.load(f) | 12 | config = json.load(f) |
13 | 13 | ||
14 | args = config["args"] | 14 | args = config["args"] |
@@ -19,11 +19,17 @@ def load_config(filename): | |||
19 | return args | 19 | return args |
20 | 20 | ||
21 | 21 | ||
22 | def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path): | 22 | def load_embeddings_from_dir( |
23 | tokenizer: MultiCLIPTokenizer, | ||
24 | embeddings: ManagedCLIPTextEmbeddings, | ||
25 | embeddings_dir: Path, | ||
26 | ): | ||
23 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 27 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
24 | return [] | 28 | return [], [] |
25 | 29 | ||
26 | filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()] | 30 | filenames = [ |
31 | filename for filename in embeddings_dir.iterdir() if filename.is_file() | ||
32 | ] | ||
27 | tokens = [filename.stem for filename in filenames] | 33 | tokens = [filename.stem for filename in filenames] |
28 | 34 | ||
29 | new_ids: list[list[int]] = [] | 35 | new_ids: list[list[int]] = [] |
@@ -39,7 +45,7 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC | |||
39 | 45 | ||
40 | embeddings.resize(len(tokenizer)) | 46 | embeddings.resize(len(tokenizer)) |
41 | 47 | ||
42 | for (new_id, embeds) in zip(new_ids, new_embeds): | 48 | for new_id, embeds in zip(new_ids, new_embeds): |
43 | embeddings.add_embed(new_id, embeds) | 49 | embeddings.add_embed(new_id, embeds) |
44 | 50 | ||
45 | return tokens, new_ids | 51 | return tokens, new_ids |