From fd691d762820863c5236a189a752ba4f985a961b Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 22 Dec 2022 16:37:47 +0100
Subject: Improved Textual Inversion: Completely exclude untrained embeddings
 from training

---
 train_ti.py | 24 ++++++------------------
 1 file changed, 6 insertions(+), 18 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 198cf37..bb51dc2 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -24,7 +24,8 @@ from common import load_text_embeddings, load_text_embedding
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
-from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
+from training.ti import patch_trainable_embeddings
+from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -512,24 +513,14 @@ def main():
         for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
             load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin"))
 
-    original_token_embeds = token_embeds.clone().to(accelerator.device)
-
     initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
 
     for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
         token_embeds[token_id] = embeddings
 
-    index_fixed_tokens = torch.arange(len(tokenizer))
-    index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
 
-    # Freeze vae and unet
-    freeze_params(vae.parameters())
-    freeze_params(unet.parameters())
-    # Freeze all parameters except for the token embeddings in text encoder
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-    ))
+    text_embeddings = patch_trainable_embeddings(text_encoder, placeholder_token_id)
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -843,10 +834,7 @@ def main():
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
-                # Let's make sure we don't update any embedding weights besides the newly added token
-                with torch.no_grad():
-                    text_encoder.get_input_embeddings(
-                    ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
+                text_embeddings.save()
 
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
-- 
cgit v1.2.3-54-g00ecf
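
Note: training/ti.py itself is not part of this patch, so the implementation of
patch_trainable_embeddings and the save() method it returns is not shown here. A
minimal sketch of how such a patch could work, assuming a wrapper module (the
name TrainableEmbeddings and all details below are illustrative, not the
repository's actual code): it swaps the CLIP token embedding for a module that
looks everything up in the frozen table and substitutes a small trainable
parameter only at the placeholder-token positions, so untrained rows can never
receive gradient updates.

    import torch
    import torch.nn as nn


    class TrainableEmbeddings(nn.Module):
        """Embedding wrapper: only the placeholder-token rows receive gradients."""

        def __init__(self, embedding: nn.Embedding, trainable_ids: list):
            super().__init__()
            self.embedding = embedding
            self.embedding.weight.requires_grad_(False)  # freeze the full table
            self.trainable_ids = trainable_ids
            # Trainable copies of just the placeholder rows
            self.trainable = nn.Parameter(
                embedding.weight[trainable_ids].detach().clone()
            )

        def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
            # Look up everything in the frozen table, then substitute the
            # trainable rows wherever a placeholder token appears.
            embeds = self.embedding(input_ids)
            for i, token_id in enumerate(self.trainable_ids):
                embeds[input_ids == token_id] = self.trainable[i]
            return embeds

        def save(self):
            # Write the trained rows back into the frozen table so that
            # checkpoints and inference pipelines see the updated embeddings.
            with torch.no_grad():
                self.embedding.weight[self.trainable_ids] = self.trainable


    def patch_trainable_embeddings(text_encoder, placeholder_token_id):
        embeddings = TrainableEmbeddings(
            text_encoder.get_input_embeddings(), placeholder_token_id
        )
        text_encoder.text_model.embeddings.token_embedding = embeddings
        return embeddings

Under this assumption, the optimizer only needs the single trainable parameter
(e.g. optimizer = AdamW([text_embeddings.trainable], ...)), and the
text_embeddings.save() call after each optimizer step replaces the old approach
of restoring the fixed rows from original_token_embeds under torch.no_grad().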