diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 15 |
1 files changed, 2 insertions, 13 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 939a8f3..ab3ed16 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -218,11 +218,6 @@ def parse_args(): | |||
218 | help="The embeddings directory where Textual Inversion embeddings are stored.", | 218 | help="The embeddings directory where Textual Inversion embeddings are stored.", |
219 | ) | 219 | ) |
220 | parser.add_argument( | 220 | parser.add_argument( |
221 | "--train_dir_embeddings", | ||
222 | action="store_true", | ||
223 | help="Train embeddings loaded from embeddings directory.", | ||
224 | ) | ||
225 | parser.add_argument( | ||
226 | "--collection", | 221 | "--collection", |
227 | type=str, | 222 | type=str, |
228 | nargs="*", | 223 | nargs="*", |
@@ -696,19 +691,13 @@ def main(): | |||
696 | tokenizer, embeddings, embeddings_dir | 691 | tokenizer, embeddings, embeddings_dir |
697 | ) | 692 | ) |
698 | 693 | ||
699 | placeholder_tokens = added_tokens | ||
700 | placeholder_token_ids = added_ids | ||
701 | |||
702 | print( | 694 | print( |
703 | f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" | 695 | f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" |
704 | ) | 696 | ) |
705 | 697 | ||
706 | if args.train_dir_embeddings: | 698 | embeddings.persist() |
707 | print("Training embeddings from embeddings dir") | ||
708 | else: | ||
709 | embeddings.persist() | ||
710 | 699 | ||
711 | if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: | 700 | if len(args.placeholder_tokens) != 0: |
712 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 701 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
713 | tokenizer=tokenizer, | 702 | tokenizer=tokenizer, |
714 | embeddings=embeddings, | 703 | embeddings=embeddings, |