diff options
Diffstat (limited to 'util')
| -rw-r--r-- | util/files.py | 16 | 
1 files changed, 11 insertions, 5 deletions
diff --git a/util/files.py b/util/files.py index 2712525..73ff802 100644 --- a/util/files.py +++ b/util/files.py  | |||
| @@ -8,7 +8,7 @@ from safetensors import safe_open | |||
| 8 | 8 | ||
| 9 | 9 | ||
| 10 | def load_config(filename): | 10 | def load_config(filename): | 
| 11 | with open(filename, 'rt') as f: | 11 | with open(filename, "rt") as f: | 
| 12 | config = json.load(f) | 12 | config = json.load(f) | 
| 13 | 13 | ||
| 14 | args = config["args"] | 14 | args = config["args"] | 
| @@ -19,11 +19,17 @@ def load_config(filename): | |||
| 19 | return args | 19 | return args | 
| 20 | 20 | ||
| 21 | 21 | ||
| 22 | def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path): | 22 | def load_embeddings_from_dir( | 
| 23 | tokenizer: MultiCLIPTokenizer, | ||
| 24 | embeddings: ManagedCLIPTextEmbeddings, | ||
| 25 | embeddings_dir: Path, | ||
| 26 | ): | ||
| 23 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 27 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 
| 24 | return [] | 28 | return [], [] | 
| 25 | 29 | ||
| 26 | filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()] | 30 | filenames = [ | 
| 31 | filename for filename in embeddings_dir.iterdir() if filename.is_file() | ||
| 32 | ] | ||
| 27 | tokens = [filename.stem for filename in filenames] | 33 | tokens = [filename.stem for filename in filenames] | 
| 28 | 34 | ||
| 29 | new_ids: list[list[int]] = [] | 35 | new_ids: list[list[int]] = [] | 
| @@ -39,7 +45,7 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC | |||
| 39 | 45 | ||
| 40 | embeddings.resize(len(tokenizer)) | 46 | embeddings.resize(len(tokenizer)) | 
| 41 | 47 | ||
| 42 | for (new_id, embeds) in zip(new_ids, new_embeds): | 48 | for new_id, embeds in zip(new_ids, new_embeds): | 
| 43 | embeddings.add_embed(new_id, embeds) | 49 | embeddings.add_embed(new_id, embeds) | 
| 44 | 50 | ||
| 45 | return tokens, new_ids | 51 | return tokens, new_ids | 
