summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-24 14:30:15 +0100
committerVolpeon <git@volpeon.ink>2022-12-24 14:30:15 +0100
commitd2105d96fdd18da035d2ad412e3fb6f579d5571a (patch)
treef6b5ff7f817875bcb086e88e7b4a9eebd537adbe /train_dreambooth.py
parentTraining update (diff)
downloadtextual-inversion-diff-d2105d96fdd18da035d2ad412e3fb6f579d5571a.tar.gz
textual-inversion-diff-d2105d96fdd18da035d2ad412e3fb6f579d5571a.tar.bz2
textual-inversion-diff-d2105d96fdd18da035d2ad412e3fb6f579d5571a.zip
Fixed Textual Inversion
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py9
1 files changed, 8 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 51e881a..8cb6414 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -568,9 +568,16 @@ def main():
568 print(f"Training entire text encoder.") 568 print(f"Training entire text encoder.")
569 else: 569 else:
570 print(f"Training added text embeddings") 570 print(f"Training added text embeddings")
571 text_encoder.requires_grad_(False) 571
572 patch_trainable_embeddings(text_encoder, placeholder_token_id) 572 patch_trainable_embeddings(text_encoder, placeholder_token_id)
573 573
574 freeze_params(itertools.chain(
575 text_encoder.text_model.encoder.parameters(),
576 text_encoder.text_model.final_layer_norm.parameters(),
577 text_encoder.text_model.embeddings.position_embedding.parameters(),
578 text_encoder.text_model.embeddings.token_embedding.parameters(),
579 ))
580
574 prompt_processor = PromptProcessor(tokenizer, text_encoder) 581 prompt_processor = PromptProcessor(tokenizer, text_encoder)
575 582
576 if args.scale_lr: 583 if args.scale_lr: