summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--train_dreambooth.py15
-rw-r--r--training/strategy/dreambooth.py11
2 files changed, 9 insertions, 17 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 939a8f3..ab3ed16 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -218,11 +218,6 @@ def parse_args():
218 help="The embeddings directory where Textual Inversion embeddings are stored.", 218 help="The embeddings directory where Textual Inversion embeddings are stored.",
219 ) 219 )
220 parser.add_argument( 220 parser.add_argument(
221 "--train_dir_embeddings",
222 action="store_true",
223 help="Train embeddings loaded from embeddings directory.",
224 )
225 parser.add_argument(
226 "--collection", 221 "--collection",
227 type=str, 222 type=str,
228 nargs="*", 223 nargs="*",
@@ -696,19 +691,13 @@ def main():
696 tokenizer, embeddings, embeddings_dir 691 tokenizer, embeddings, embeddings_dir
697 ) 692 )
698 693
699 placeholder_tokens = added_tokens
700 placeholder_token_ids = added_ids
701
702 print( 694 print(
703 f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" 695 f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
704 ) 696 )
705 697
706 if args.train_dir_embeddings: 698 embeddings.persist()
707 print("Training embeddings from embeddings dir")
708 else:
709 embeddings.persist()
710 699
711 if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: 700 if len(args.placeholder_tokens) != 0:
712 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( 701 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
713 tokenizer=tokenizer, 702 tokenizer=tokenizer,
714 embeddings=embeddings, 703 embeddings=embeddings,
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 43fe838..35cccbb 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -203,10 +203,13 @@ def dreambooth_prepare(
203 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler 203 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
204 ) 204 )
205 205
206 for layer in text_encoder.text_model.encoder.layers[ 206 if text_encoder_unfreeze_last_n_layers == 0:
207 : (-1 * text_encoder_unfreeze_last_n_layers) 207 text_encoder.text_model.encoder.requires_grad_(False)
208 ]: 208 elif text_encoder_unfreeze_last_n_layers > 0:
209 layer.requires_grad_(False) 209 for layer in text_encoder.text_model.encoder.layers[
210 : (-1 * text_encoder_unfreeze_last_n_layers)
211 ]:
212 layer.requires_grad_(False)
210 213
211 text_encoder.text_model.embeddings.requires_grad_(False) 214 text_encoder.text_model.embeddings.requires_grad_(False)
212 215