From 27b18776ba6d38d6bda5e5bafee3e7c4ca8c9712 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Jun 2023 16:26:22 +0200 Subject: Fixes --- util/files.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) (limited to 'util') diff --git a/util/files.py b/util/files.py index 2712525..73ff802 100644 --- a/util/files.py +++ b/util/files.py @@ -8,7 +8,7 @@ from safetensors import safe_open def load_config(filename): - with open(filename, 'rt') as f: + with open(filename, "rt") as f: config = json.load(f) args = config["args"] @@ -19,11 +19,17 @@ def load_config(filename): return args -def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path): +def load_embeddings_from_dir( + tokenizer: MultiCLIPTokenizer, + embeddings: ManagedCLIPTextEmbeddings, + embeddings_dir: Path, +): if not embeddings_dir.exists() or not embeddings_dir.is_dir(): - return [] + return [], [] - filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()] + filenames = [ + filename for filename in embeddings_dir.iterdir() if filename.is_file() + ] tokens = [filename.stem for filename in filenames] new_ids: list[list[int]] = [] @@ -39,7 +45,7 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC embeddings.resize(len(tokenizer)) - for (new_id, embeds) in zip(new_ids, new_embeds): + for new_id, embeds in zip(new_ids, new_embeds): embeddings.add_embed(new_id, embeds) return tokens, new_ids -- cgit v1.2.3-70-g09d2