From d2105d96fdd18da035d2ad412e3fb6f579d5571a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Dec 2022 14:30:15 +0100 Subject: Fixed Textual Inversion --- train_ti.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index a12b889..5f37d54 100644 --- a/train_ti.py +++ b/train_ti.py @@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule, CSVDataItem from training.optimization import get_one_cycle_schedule from training.ti import patch_trainable_embeddings -from training.util import AverageMeter, CheckpointerBase, save_args +from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params from models.clip.prompt import PromptProcessor logger = get_logger(__name__) @@ -515,10 +515,16 @@ def main(): vae.requires_grad_(False) unet.requires_grad_(False) - text_encoder.requires_grad_(False) patch_trainable_embeddings(text_encoder, placeholder_token_id) + freeze_params(itertools.chain( + text_encoder.text_model.encoder.parameters(), + text_encoder.text_model.final_layer_norm.parameters(), + text_encoder.text_model.embeddings.position_embedding.parameters(), + text_encoder.text_model.embeddings.token_embedding.parameters(), + )) + prompt_processor = PromptProcessor(tokenizer, text_encoder) if args.scale_lr: -- cgit v1.2.3-54-g00ecf