From 279174a7a31f0fc6ed209e5b46901e50fe722c87 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 14 Dec 2022 09:43:45 +0100 Subject: More generic datset filter --- dreambooth.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'dreambooth.py') diff --git a/dreambooth.py b/dreambooth.py index 3f45754..96213d0 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -629,7 +629,11 @@ def main(): vae.requires_grad_(False) if args.embeddings_dir is not None: - load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir)) + embeddings_dir = Path(args.embeddings_dir) + if not embeddings_dir.exists() or not embeddings_dir.is_dir(): + raise ValueError("--embeddings_dir must point to an existing directory") + added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir) + print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}") if len(args.placeholder_token) != 0: # Convert the initializer_token, placeholder_token to ids -- cgit v1.2.3-54-g00ecf