diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/dreambooth.py b/dreambooth.py index 3f45754..96213d0 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -629,7 +629,11 @@ def main(): | |||
629 | vae.requires_grad_(False) | 629 | vae.requires_grad_(False) |
630 | 630 | ||
631 | if args.embeddings_dir is not None: | 631 | if args.embeddings_dir is not None: |
632 | load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir)) | 632 | embeddings_dir = Path(args.embeddings_dir) |
633 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | ||
634 | raise ValueError("--embeddings_dir must point to an existing directory") | ||
635 | added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir) | ||
636 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}") | ||
633 | 637 | ||
634 | if len(args.placeholder_token) != 0: | 638 | if len(args.placeholder_token) != 0: |
635 | # Convert the initializer_token, placeholder_token to ids | 639 | # Convert the initializer_token, placeholder_token to ids |