# training/ti.py — Improved Textual Inversion: completely exclude untrained
# embeddings from training by routing only the new token ids through a small
# trainable nn.Embedding while the full CLIP embedding matrix stays frozen.

from typing import Optional

import torch
import torch.nn as nn

from transformers.models.clip import CLIPTextModel, CLIPTextConfig
from transformers.models.clip.modeling_clip import CLIPTextEmbeddings


def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
    """Swap the encoder's embeddings for a :class:`TrainableEmbeddings`.

    The frozen token/position weights are *shared* (same Parameter objects,
    no copy) with the original module, so memory use is unchanged and
    ``save()`` writes trained rows back into the encoder's own matrix.

    Returns the new embeddings module so the caller can pass its parameters
    to an optimizer and later call ``save()``.
    """
    text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids)
    # Share, don't clone: rebind the new module's weights to the encoder's.
    text_embeddings.token_embedding.weight = text_encoder.text_model.embeddings.token_embedding.weight
    text_embeddings.position_embedding.weight = text_encoder.text_model.embeddings.position_embedding.weight
    text_encoder.text_model.embeddings = text_embeddings
    return text_embeddings


class TrainableEmbeddings(CLIPTextEmbeddings):
    """CLIP text embeddings where only the rows in ``new_ids`` get gradients.

    The full (frozen) embedding matrix is kept for all pre-existing tokens;
    the rows belonging to ``new_ids`` are duplicated into a small, trainable
    ``nn.Embedding`` and selected per token at forward time.
    """

    def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
        super().__init__(config)

        # Freeze the pretrained matrices; only `trainable_embedding` trains.
        self.token_embedding.requires_grad_(False)
        self.position_embedding.requires_grad_(False)

        # Token ids (ascending) whose embeddings are trainable. Ids outside
        # the vocabulary range are dropped by the isin() filter.
        indices = torch.arange(self.token_embedding.num_embeddings)
        self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))]

        # Vocab-sized lookup table: token id -> row in `trainable_embedding`.
        # Rows are ordered by `train_indices` (ascending token id), so the
        # table is built from the same ordering. This replaces a dict keyed
        # by 0-dim tensors, which hashed by identity and always missed —
        # every trainable token silently mapped to row 0. It also fixes the
        # ordering mismatch when `new_ids` was not sorted.
        lookup = torch.zeros(self.token_embedding.num_embeddings, dtype=torch.long)
        lookup[self.train_indices] = torch.arange(len(self.train_indices))
        self.register_buffer("trainable_row_lookup", lookup, persistent=False)

        # Kept for backward compatibility with callers that inspected it;
        # now consistent with the actual row order of `trainable_embedding`.
        self.id_mapping = {
            int(token_id): row for row, token_id in enumerate(self.train_indices.tolist())
        }

        # freeze=False is essential: from_pretrained() defaults to
        # freeze=True, which would leave these rows without gradients and
        # defeat the purpose of this module.
        self.trainable_embedding = nn.Embedding.from_pretrained(
            self.token_embedding.weight[self.train_indices], freeze=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

        if position_ids is None:
            # `position_ids` buffer is inherited from CLIPTextEmbeddings.
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            # Per-token flag: does this id have a trainable row? Broadcasts
            # over the embedding dimension in torch.where below.
            mask = torch.isin(
                input_ids, self.train_indices.to(input_ids.device)
            ).unsqueeze(-1)

            # Vectorized id -> trainable-row translation (untrained ids map
            # to row 0, but `mask` discards those lanes anyway).
            trainable_input_ids = self.trainable_row_lookup[input_ids]

            inputs_embeds = torch.where(
                mask,
                self.trainable_embedding(trainable_input_ids),
                self.token_embedding(input_ids),
            )

        position_embeddings = self.position_embedding(position_ids)
        return inputs_embeds + position_embeddings

    @torch.no_grad()
    def save(self):
        """Copy the trained rows back into the frozen full embedding matrix."""
        self.token_embedding.weight.data[self.train_indices] = self.trainable_embedding.weight.data