summaryrefslogtreecommitdiffstats
path: root/common.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-14 09:43:45 +0100
committerVolpeon <git@volpeon.ink>2022-12-14 09:43:45 +0100
commit279174a7a31f0fc6ed209e5b46901e50fe722c87 (patch)
treeec12ec9a66c5e6532aa0be08608c638283e090fb /common.py
parentUnified loading of TI embeddings (diff)
downloadtextual-inversion-diff-279174a7a31f0fc6ed209e5b46901e50fe722c87.tar.gz
textual-inversion-diff-279174a7a31f0fc6ed209e5b46901e50fe722c87.tar.bz2
textual-inversion-diff-279174a7a31f0fc6ed209e5b46901e50fe722c87.zip
More generic datset filter
Diffstat (limited to 'common.py')
-rw-r--r--common.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/common.py b/common.py
index 8d6b55d..7ffa77f 100644
--- a/common.py
+++ b/common.py
@@ -18,7 +18,7 @@ def load_text_embedding(embeddings, token_id, file):
18 18
19def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path): 19def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path):
20 if not embeddings_dir.exists() or not embeddings_dir.is_dir(): 20 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
21 return 0 21 return []
22 22
23 files = [file for file in embeddings_dir.iterdir() if file.is_file()] 23 files = [file for file in embeddings_dir.iterdir() if file.is_file()]
24 24
@@ -33,4 +33,4 @@ def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel,
33 for (token_id, file) in zip(token_ids, files): 33 for (token_id, file) in zip(token_ids, files):
34 load_text_embedding(token_embeds, token_id, file) 34 load_text_embedding(token_embeds, token_id, file)
35 35
36 return added 36 return tokens