summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py15
1 files changed, 2 insertions, 13 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 939a8f3..ab3ed16 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -218,11 +218,6 @@ def parse_args():
218 help="The embeddings directory where Textual Inversion embeddings are stored.", 218 help="The embeddings directory where Textual Inversion embeddings are stored.",
219 ) 219 )
220 parser.add_argument( 220 parser.add_argument(
221 "--train_dir_embeddings",
222 action="store_true",
223 help="Train embeddings loaded from embeddings directory.",
224 )
225 parser.add_argument(
226 "--collection", 221 "--collection",
227 type=str, 222 type=str,
228 nargs="*", 223 nargs="*",
@@ -696,19 +691,13 @@ def main():
696 tokenizer, embeddings, embeddings_dir 691 tokenizer, embeddings, embeddings_dir
697 ) 692 )
698 693
699 placeholder_tokens = added_tokens
700 placeholder_token_ids = added_ids
701
702 print( 694 print(
703 f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" 695 f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
704 ) 696 )
705 697
706 if args.train_dir_embeddings: 698 embeddings.persist()
707 print("Training embeddings from embeddings dir")
708 else:
709 embeddings.persist()
710 699
711 if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: 700 if len(args.placeholder_tokens) != 0:
712 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( 701 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
713 tokenizer=tokenizer, 702 tokenizer=tokenizer,
714 embeddings=embeddings, 703 embeddings=embeddings,