diff options
author | Volpeon <git@volpeon.ink> | 2022-12-14 09:43:45 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-14 09:43:45 +0100 |
commit | 279174a7a31f0fc6ed209e5b46901e50fe722c87 (patch) | |
tree | ec12ec9a66c5e6532aa0be08608c638283e090fb /dreambooth.py | |
parent | Unified loading of TI embeddings (diff) | |
download | textual-inversion-diff-279174a7a31f0fc6ed209e5b46901e50fe722c87.tar.gz textual-inversion-diff-279174a7a31f0fc6ed209e5b46901e50fe722c87.tar.bz2 textual-inversion-diff-279174a7a31f0fc6ed209e5b46901e50fe722c87.zip |
More generic datset filter
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/dreambooth.py b/dreambooth.py index 3f45754..96213d0 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -629,7 +629,11 @@ def main(): | |||
629 | vae.requires_grad_(False) | 629 | vae.requires_grad_(False) |
630 | 630 | ||
631 | if args.embeddings_dir is not None: | 631 | if args.embeddings_dir is not None: |
632 | load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir)) | 632 | embeddings_dir = Path(args.embeddings_dir) |
633 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | ||
634 | raise ValueError("--embeddings_dir must point to an existing directory") | ||
635 | added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir) | ||
636 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}") | ||
633 | 637 | ||
634 | if len(args.placeholder_token) != 0: | 638 | if len(args.placeholder_token) != 0: |
635 | # Convert the initializer_token, placeholder_token to ids | 639 | # Convert the initializer_token, placeholder_token to ids |