From ee9a2777c15d4ceea7ef40802b9a21881f6428a8 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 22 Dec 2022 21:15:24 +0100 Subject: Fixed Textual Inversion --- models/clip/prompt.py | 6 +++--- train_ti.py | 12 +++++++----- training/ti.py | 9 +++------ 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/models/clip/prompt.py b/models/clip/prompt.py index 9b427a0..da33ecf 100644 --- a/models/clip/prompt.py +++ b/models/clip/prompt.py @@ -27,10 +27,10 @@ class PromptProcessor(): def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None): prompts = input_ids.shape[0] - input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) + input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) if attention_mask is not None: - attention_mask = attention_mask.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) + attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0] - text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2])) + text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) return text_embeddings diff --git a/train_ti.py b/train_ti.py index bb51dc2..e933c48 100644 --- a/train_ti.py +++ b/train_ti.py @@ -365,6 +365,7 @@ class Checkpointer(CheckpointerBase): tokenizer, text_encoder, scheduler, + text_embeddings, instance_identifier, placeholder_token, placeholder_token_id, @@ -392,6 +393,7 @@ class Checkpointer(CheckpointerBase): self.tokenizer = tokenizer self.text_encoder = text_encoder self.scheduler = scheduler + self.text_embeddings = text_embeddings @torch.no_grad() def checkpoint(self, step, postfix): @@ -403,8 +405,10 @@ class Checkpointer(CheckpointerBase): text_encoder = self.accelerator.unwrap_model(self.text_encoder) for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): + training_token_id = self.text_embeddings.id_mapping[placeholder_token_id] + # Save a checkpoint - learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id] + learned_embeds = self.text_embeddings.trainable_embedding.weight[training_token_id] learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) @@ -543,7 +547,7 @@ def main(): # Initialize the optimizer optimizer = optimizer_class( - text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings + text_embeddings.trainable_embedding.parameters(), # only optimize the embeddings lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, @@ -741,6 +745,7 @@ def main(): tokenizer=tokenizer, text_encoder=text_encoder, scheduler=checkpoint_scheduler, + text_embeddings=text_embeddings, instance_identifier=args.instance_identifier, placeholder_token=args.placeholder_token, placeholder_token_id=placeholder_token_id, @@ -774,7 +779,6 @@ def main(): local_progress_bar.reset() text_encoder.train() - train_loss = 0.0 for step, batch in enumerate(train_dataloader): with accelerator.accumulate(text_encoder): @@ -834,8 +838,6 @@ def main(): lr_scheduler.step() optimizer.zero_grad(set_to_none=True) - text_embeddings.save() - avg_loss.update(loss.detach_(), bsz) avg_acc.update(acc.detach_(), bsz) diff --git a/training/ti.py b/training/ti.py index a5fd8e4..2efd2f2 100644 --- a/training/ti.py +++ b/training/ti.py @@ -19,8 +19,8 @@ class TrainableEmbeddings(CLIPTextEmbeddings): def __init__(self, config: CLIPTextConfig, new_ids: list[int]): super().__init__(config) - self.token_embedding.requires_grad_(False) - self.position_embedding.requires_grad_(False) + self.token_embedding.weight.requires_grad = False + self.position_embedding.weight.requires_grad = False self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))} @@ -28,6 +28,7 @@ class TrainableEmbeddings(CLIPTextEmbeddings): self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))] self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices]) + self.trainable_embedding.weight.requires_grad = True def forward( self, @@ -64,7 +65,3 @@ class TrainableEmbeddings(CLIPTextEmbeddings): embeddings = inputs_embeds + position_embeddings return embeddings - - @torch.no_grad() - def save(self): - self.token_embedding.weight.data[self.train_indices] = self.trainable_embedding.weight.data -- cgit v1.2.3-54-g00ecf