diff options
author | Volpeon <git@volpeon.ink> | 2022-12-24 14:30:15 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-24 14:30:15 +0100 |
commit | d2105d96fdd18da035d2ad412e3fb6f579d5571a (patch) | |
tree | f6b5ff7f817875bcb086e88e7b4a9eebd537adbe /train_dreambooth.py | |
parent | Training update (diff) | |
download | textual-inversion-diff-d2105d96fdd18da035d2ad412e3fb6f579d5571a.tar.gz textual-inversion-diff-d2105d96fdd18da035d2ad412e3fb6f579d5571a.tar.bz2 textual-inversion-diff-d2105d96fdd18da035d2ad412e3fb6f579d5571a.zip |
Fixed Textual Inversion
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 51e881a..8cb6414 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -568,9 +568,16 @@ def main(): | |||
568 | print(f"Training entire text encoder.") | 568 | print(f"Training entire text encoder.") |
569 | else: | 569 | else: |
570 | print(f"Training added text embeddings") | 570 | print(f"Training added text embeddings") |
571 | text_encoder.requires_grad_(False) | 571 | |
572 | patch_trainable_embeddings(text_encoder, placeholder_token_id) | 572 | patch_trainable_embeddings(text_encoder, placeholder_token_id) |
573 | 573 | ||
574 | freeze_params(itertools.chain( | ||
575 | text_encoder.text_model.encoder.parameters(), | ||
576 | text_encoder.text_model.final_layer_norm.parameters(), | ||
577 | text_encoder.text_model.embeddings.position_embedding.parameters(), | ||
578 | text_encoder.text_model.embeddings.token_embedding.parameters(), | ||
579 | )) | ||
580 | |||
574 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | 581 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
575 | 582 | ||
576 | if args.scale_lr: | 583 | if args.scale_lr: |