From d2105d96fdd18da035d2ad412e3fb6f579d5571a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 24 Dec 2022 14:30:15 +0100
Subject: Fixed Textual Inversion

---
 train_dreambooth.py |  9 ++++++++-
 train_ti.py         | 10 ++++++++--
 training/ti.py      | 15 +++++----------
 3 files changed, 21 insertions(+), 13 deletions(-)

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 51e881a..8cb6414 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -568,9 +568,16 @@ def main():
 
         print(f"Training entire text encoder.")
     else:
         print(f"Training added text embeddings")
 
-        text_encoder.requires_grad_(False)
+        patch_trainable_embeddings(text_encoder, placeholder_token_id)
+        freeze_params(itertools.chain(
+            text_encoder.text_model.encoder.parameters(),
+            text_encoder.text_model.final_layer_norm.parameters(),
+            text_encoder.text_model.embeddings.position_embedding.parameters(),
+            text_encoder.text_model.embeddings.token_embedding.parameters(),
+        ))
+
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
diff --git a/train_ti.py b/train_ti.py
index a12b889..5f37d54 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from training.ti import patch_trainable_embeddings
-from training.util import AverageMeter, CheckpointerBase, save_args
+from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -515,10 +515,16 @@ def main():
 
     vae.requires_grad_(False)
     unet.requires_grad_(False)
-    text_encoder.requires_grad_(False)
 
     patch_trainable_embeddings(text_encoder, placeholder_token_id)
 
+    freeze_params(itertools.chain(
+        text_encoder.text_model.encoder.parameters(),
+        text_encoder.text_model.final_layer_norm.parameters(),
+        text_encoder.text_model.embeddings.position_embedding.parameters(),
+        text_encoder.text_model.embeddings.token_embedding.parameters(),
+    ))
+
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
diff --git a/training/ti.py b/training/ti.py
index 8b2fdd6..dc33e5e 100644
--- a/training/ti.py
+++ b/training/ti.py
@@ -8,26 +8,21 @@ from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
 
 def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
-    text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids)
-
-    text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding
-    text_embeddings.token_embedding.weight.requires_grad = False
-
-    text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding
-    text_embeddings.position_embedding.weight.requires_grad = False
-
+    text_embeddings = TrainableEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, new_ids)
     text_encoder.text_model.embeddings = text_embeddings
 
 
 class TrainableEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]):
         super().__init__(config)
 
         self.train_indices = torch.tensor(new_ids)
 
         self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
+
+        self.token_embedding = embeddings.token_embedding
+        self.position_embedding = embeddings.position_embedding
         self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone()
-        self.trainable_embedding.weight.requires_grad = True
 
     def forward(
         self,
--
cgit v1.2.3-70-g09d2
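
Note on the training/ti.py change: the fix passes the encoder's existing embeddings into TrainableEmbeddings, so the trainable copy is cloned from the pretrained token embedding rather than from the freshly initialized one created by super().__init__(config). The forward body is not shown in this patch; the snippet below is only a minimal sketch of the routing such a module could perform, assuming train_indices holds the placeholder token ids and that only the cloned trainable table should receive gradients. The class name TrainableEmbeddingsSketch and the vocabulary/embedding sizes are illustrative, not taken from the repository.

import torch
import torch.nn as nn


class TrainableEmbeddingsSketch(nn.Module):
    # Hypothetical illustration, not the repository's forward() implementation:
    # keep the pretrained token embedding frozen and route only the newly added
    # placeholder ids through a trainable copy cloned from it.
    def __init__(self, frozen_token_embedding: nn.Embedding, new_ids: list[int]):
        super().__init__()
        self.token_embedding = frozen_token_embedding
        self.train_indices = torch.tensor(new_ids)
        self.trainable_embedding = nn.Embedding(
            frozen_token_embedding.num_embeddings, frozen_token_embedding.embedding_dim)
        # Clone the pretrained weights, mirroring the patched __init__ above.
        self.trainable_embedding.weight.data = frozen_token_embedding.weight.data.clone()

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        frozen = self.token_embedding(input_ids)
        trainable = self.trainable_embedding(input_ids)
        # Use the trainable rows only where the token id is a placeholder id.
        mask = torch.isin(input_ids, self.train_indices.to(input_ids.device)).unsqueeze(-1)
        return torch.where(mask, trainable, frozen)


# Example: gradients reach only the trainable copy.
pretrained = nn.Embedding(49410, 768)
pretrained.weight.requires_grad_(False)
embeddings = TrainableEmbeddingsSketch(pretrained, new_ids=[49408, 49409])
out = embeddings(torch.tensor([[49408, 320, 49409]]))
out.sum().backward()
assert embeddings.trainable_embedding.weight.grad is not None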