summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 3f45754..96213d0 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -629,7 +629,11 @@ def main():
629 vae.requires_grad_(False) 629 vae.requires_grad_(False)
630 630
631 if args.embeddings_dir is not None: 631 if args.embeddings_dir is not None:
632 load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir)) 632 embeddings_dir = Path(args.embeddings_dir)
633 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
634 raise ValueError("--embeddings_dir must point to an existing directory")
635 added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
636 print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
633 637
634 if len(args.placeholder_token) != 0: 638 if len(args.placeholder_token) != 0:
635 # Convert the initializer_token, placeholder_token to ids 639 # Convert the initializer_token, placeholder_token to ids