diff options
-rw-r--r-- | train_dreambooth.py | 15 | ||||
-rw-r--r-- | training/strategy/dreambooth.py | 11 |
2 files changed, 9 insertions, 17 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 939a8f3..ab3ed16 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -218,11 +218,6 @@ def parse_args(): | |||
218 | help="The embeddings directory where Textual Inversion embeddings are stored.", | 218 | help="The embeddings directory where Textual Inversion embeddings are stored.", |
219 | ) | 219 | ) |
220 | parser.add_argument( | 220 | parser.add_argument( |
221 | "--train_dir_embeddings", | ||
222 | action="store_true", | ||
223 | help="Train embeddings loaded from embeddings directory.", | ||
224 | ) | ||
225 | parser.add_argument( | ||
226 | "--collection", | 221 | "--collection", |
227 | type=str, | 222 | type=str, |
228 | nargs="*", | 223 | nargs="*", |
@@ -696,19 +691,13 @@ def main(): | |||
696 | tokenizer, embeddings, embeddings_dir | 691 | tokenizer, embeddings, embeddings_dir |
697 | ) | 692 | ) |
698 | 693 | ||
699 | placeholder_tokens = added_tokens | ||
700 | placeholder_token_ids = added_ids | ||
701 | |||
702 | print( | 694 | print( |
703 | f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" | 695 | f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" |
704 | ) | 696 | ) |
705 | 697 | ||
706 | if args.train_dir_embeddings: | 698 | embeddings.persist() |
707 | print("Training embeddings from embeddings dir") | ||
708 | else: | ||
709 | embeddings.persist() | ||
710 | 699 | ||
711 | if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: | 700 | if len(args.placeholder_tokens) != 0: |
712 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 701 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
713 | tokenizer=tokenizer, | 702 | tokenizer=tokenizer, |
714 | embeddings=embeddings, | 703 | embeddings=embeddings, |
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 43fe838..35cccbb 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -203,10 +203,13 @@ def dreambooth_prepare( | |||
203 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 203 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
204 | ) | 204 | ) |
205 | 205 | ||
206 | for layer in text_encoder.text_model.encoder.layers[ | 206 | if text_encoder_unfreeze_last_n_layers == 0: |
207 | : (-1 * text_encoder_unfreeze_last_n_layers) | 207 | text_encoder.text_model.encoder.requires_grad_(False) |
208 | ]: | 208 | elif text_encoder_unfreeze_last_n_layers > 0: |
209 | layer.requires_grad_(False) | 209 | for layer in text_encoder.text_model.encoder.layers[ |
210 | : (-1 * text_encoder_unfreeze_last_n_layers) | ||
211 | ]: | ||
212 | layer.requires_grad_(False) | ||
210 | 213 | ||
211 | text_encoder.text_model.embeddings.requires_grad_(False) | 214 | text_encoder.text_model.embeddings.requires_grad_(False) |
212 | 215 | ||