diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 51e881a..8cb6414 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -568,9 +568,16 @@ def main(): | |||
568 | print(f"Training entire text encoder.") | 568 | print(f"Training entire text encoder.") |
569 | else: | 569 | else: |
570 | print(f"Training added text embeddings") | 570 | print(f"Training added text embeddings") |
571 | text_encoder.requires_grad_(False) | 571 | |
572 | patch_trainable_embeddings(text_encoder, placeholder_token_id) | 572 | patch_trainable_embeddings(text_encoder, placeholder_token_id) |
573 | 573 | ||
574 | freeze_params(itertools.chain( | ||
575 | text_encoder.text_model.encoder.parameters(), | ||
576 | text_encoder.text_model.final_layer_norm.parameters(), | ||
577 | text_encoder.text_model.embeddings.position_embedding.parameters(), | ||
578 | text_encoder.text_model.embeddings.token_embedding.parameters(), | ||
579 | )) | ||
580 | |||
574 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | 581 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
575 | 582 | ||
576 | if args.scale_lr: | 583 | if args.scale_lr: |