diff options
| -rw-r--r-- | train_dreambooth.py | 15 | ||||
| -rw-r--r-- | training/strategy/dreambooth.py | 11 |
2 files changed, 9 insertions, 17 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 939a8f3..ab3ed16 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -218,11 +218,6 @@ def parse_args(): | |||
| 218 | help="The embeddings directory where Textual Inversion embeddings are stored.", | 218 | help="The embeddings directory where Textual Inversion embeddings are stored.", |
| 219 | ) | 219 | ) |
| 220 | parser.add_argument( | 220 | parser.add_argument( |
| 221 | "--train_dir_embeddings", | ||
| 222 | action="store_true", | ||
| 223 | help="Train embeddings loaded from embeddings directory.", | ||
| 224 | ) | ||
| 225 | parser.add_argument( | ||
| 226 | "--collection", | 221 | "--collection", |
| 227 | type=str, | 222 | type=str, |
| 228 | nargs="*", | 223 | nargs="*", |
| @@ -696,19 +691,13 @@ def main(): | |||
| 696 | tokenizer, embeddings, embeddings_dir | 691 | tokenizer, embeddings, embeddings_dir |
| 697 | ) | 692 | ) |
| 698 | 693 | ||
| 699 | placeholder_tokens = added_tokens | ||
| 700 | placeholder_token_ids = added_ids | ||
| 701 | |||
| 702 | print( | 694 | print( |
| 703 | f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" | 695 | f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" |
| 704 | ) | 696 | ) |
| 705 | 697 | ||
| 706 | if args.train_dir_embeddings: | 698 | embeddings.persist() |
| 707 | print("Training embeddings from embeddings dir") | ||
| 708 | else: | ||
| 709 | embeddings.persist() | ||
| 710 | 699 | ||
| 711 | if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: | 700 | if len(args.placeholder_tokens) != 0: |
| 712 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 701 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
| 713 | tokenizer=tokenizer, | 702 | tokenizer=tokenizer, |
| 714 | embeddings=embeddings, | 703 | embeddings=embeddings, |
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 43fe838..35cccbb 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -203,10 +203,13 @@ def dreambooth_prepare( | |||
| 203 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 203 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
| 204 | ) | 204 | ) |
| 205 | 205 | ||
| 206 | for layer in text_encoder.text_model.encoder.layers[ | 206 | if text_encoder_unfreeze_last_n_layers == 0: |
| 207 | : (-1 * text_encoder_unfreeze_last_n_layers) | 207 | text_encoder.text_model.encoder.requires_grad_(False) |
| 208 | ]: | 208 | elif text_encoder_unfreeze_last_n_layers > 0: |
| 209 | layer.requires_grad_(False) | 209 | for layer in text_encoder.text_model.encoder.layers[ |
| 210 | : (-1 * text_encoder_unfreeze_last_n_layers) | ||
| 211 | ]: | ||
| 212 | layer.requires_grad_(False) | ||
| 210 | 213 | ||
| 211 | text_encoder.text_model.embeddings.requires_grad_(False) | 214 | text_encoder.text_model.embeddings.requires_grad_(False) |
| 212 | 215 | ||
