From fd691d762820863c5236a189a752ba4f985a961b Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 22 Dec 2022 16:37:47 +0100
Subject: Improved Textual Inversion: Completely exclude untrained embeddings
 from training

---
 train_ti.py    | 24 +++++---------------
 training/ti.py | 70 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 76 insertions(+), 18 deletions(-)
 create mode 100644 training/ti.py

diff --git a/train_ti.py b/train_ti.py
index 198cf37..bb51dc2 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -24,7 +24,8 @@ from common import load_text_embeddings, load_text_embedding
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
-from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
+from training.ti import patch_trainable_embeddings
+from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -512,24 +513,14 @@ def main():
         for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
             load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin"))
 
-    original_token_embeds = token_embeds.clone().to(accelerator.device)
-
     initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
     for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
         token_embeds[token_id] = embeddings
 
-    index_fixed_tokens = torch.arange(len(tokenizer))
-    index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
 
-    # Freeze vae and unet
-    freeze_params(vae.parameters())
-    freeze_params(unet.parameters())
-    # Freeze all parameters except for the token embeddings in text encoder
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-    ))
+    text_embeddings = patch_trainable_embeddings(text_encoder, placeholder_token_id)
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -843,10 +834,7 @@ def main():
                     lr_scheduler.step()
                     optimizer.zero_grad(set_to_none=True)
 
-                    # Let's make sure we don't update any embedding weights besides the newly added token
-                    with torch.no_grad():
-                        text_encoder.get_input_embeddings(
-                        ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
+                    text_embeddings.save()
 
                     avg_loss.update(loss.detach_(), bsz)
                     avg_acc.update(acc.detach_(), bsz)
diff --git a/training/ti.py b/training/ti.py
new file mode 100644
index 0000000..a5fd8e4
--- /dev/null
+++ b/training/ti.py
@@ -0,0 +1,70 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from transformers.models.clip import CLIPTextModel, CLIPTextConfig
+from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
+
+
+def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
+    text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids)
+    text_embeddings.token_embedding.weight = text_encoder.text_model.embeddings.token_embedding.weight
+    text_embeddings.position_embedding.weight = text_encoder.text_model.embeddings.position_embedding.weight
+    text_encoder.text_model.embeddings = text_embeddings
+    return text_embeddings
+
+
+class TrainableEmbeddings(CLIPTextEmbeddings):
+    def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
+        super().__init__(config)
+
+        self.token_embedding.requires_grad_(False)
+        self.position_embedding.requires_grad_(False)
+
+        self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))}
+
+        indices = torch.arange(self.token_embedding.num_embeddings)
+        self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))]
+
+        self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices])
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
+        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, :seq_length]
+
+        if inputs_embeds is None:
+            mask = torch.isin(
+                input_ids,
+                self.train_indices.to(input_ids.device)
+            ).unsqueeze(-1).expand(-1, -1, self.token_embedding.embedding_dim)
+
+            trainable_input_ids = torch.tensor([
+                [
+                    self.id_mapping[id] if id in self.id_mapping else 0
+                    for id in batch
+                ]
+                for batch in input_ids
+            ], device=input_ids.device)
+
+            inputs_embeds = torch.where(
+                mask,
+                self.trainable_embedding(trainable_input_ids),
+                self.token_embedding(input_ids)
+            )
+
+        position_embeddings = self.position_embedding(position_ids)
+        embeddings = inputs_embeds + position_embeddings
+
+        return embeddings
+
+    @torch.no_grad()
+    def save(self):
+        self.token_embedding.weight.data[self.train_indices] = self.trainable_embedding.weight.data
-- 
cgit v1.2.3-54-g00ecf
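
Usage note (not part of the patch): the sketch below shows how the new helper is meant to be wired up, mirroring the train_ti.py changes in the diff above. The pretrained model name, the placeholder token, and the tokenizer setup are illustrative assumptions, not values taken from this commit.

# Illustrative sketch only; model name and placeholder token are placeholders.
from transformers import CLIPTextModel, CLIPTokenizer

from training.ti import patch_trainable_embeddings

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Register the placeholder token(s) and grow the embedding matrix to match.
tokenizer.add_tokens(["<new-concept>"])
placeholder_token_id = tokenizer.convert_tokens_to_ids(["<new-concept>"])
text_encoder.resize_token_embeddings(len(tokenizer))

# Swap in the patched embeddings module: only the rows listed in
# placeholder_token_id are served from the separate trainable_embedding;
# every other embedding row stays frozen and is never touched by training.
text_embeddings = patch_trainable_embeddings(text_encoder, placeholder_token_id)

# After each optimizer step, train_ti.py calls save() to copy the trained
# rows back into the shared token embedding matrix:
text_embeddings.save()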