From d2105d96fdd18da035d2ad412e3fb6f579d5571a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Dec 2022 14:30:15 +0100 Subject: Fixed Textual Inversion --- train_dreambooth.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index 51e881a..8cb6414 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -568,9 +568,16 @@ def main(): print(f"Training entire text encoder.") else: print(f"Training added text embeddings") - text_encoder.requires_grad_(False) + patch_trainable_embeddings(text_encoder, placeholder_token_id) + freeze_params(itertools.chain( + text_encoder.text_model.encoder.parameters(), + text_encoder.text_model.final_layer_norm.parameters(), + text_encoder.text_model.embeddings.position_embedding.parameters(), + text_encoder.text_model.embeddings.token_embedding.parameters(), + )) + prompt_processor = PromptProcessor(tokenizer, text_encoder) if args.scale_lr: -- cgit v1.2.3-54-g00ecf