From ee9a2777c15d4ceea7ef40802b9a21881f6428a8 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 22 Dec 2022 21:15:24 +0100
Subject: Fixed Textual Inversion

---
 train_ti.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index bb51dc2..e933c48 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -365,6 +365,7 @@ class Checkpointer(CheckpointerBase):
         tokenizer,
         text_encoder,
         scheduler,
+        text_embeddings,
         instance_identifier,
         placeholder_token,
         placeholder_token_id,
@@ -392,6 +393,7 @@ class Checkpointer(CheckpointerBase):
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.scheduler = scheduler
+        self.text_embeddings = text_embeddings
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
@@ -403,8 +405,10 @@ class Checkpointer(CheckpointerBase):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
+            training_token_id = self.text_embeddings.id_mapping[placeholder_token_id]
+
             # Save a checkpoint
-            learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds = self.text_embeddings.trainable_embedding.weight[training_token_id]
             learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
             filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
@@ -543,7 +547,7 @@ def main():
 
     # Initialize the optimizer
     optimizer = optimizer_class(
-        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
+        text_embeddings.trainable_embedding.parameters(),  # only optimize the embeddings
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
@@ -741,6 +745,7 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
+        text_embeddings=text_embeddings,
         instance_identifier=args.instance_identifier,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
@@ -774,7 +779,6 @@ def main():
         local_progress_bar.reset()
 
         text_encoder.train()
-        train_loss = 0.0
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
@@ -834,8 +838,6 @@ def main():
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
-                text_embeddings.save()
-
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
 
-- 
cgit v1.2.3-54-g00ecf
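
The text_embeddings object wired in above is defined elsewhere in the repository, so this patch
only shows how it is consumed: the optimizer now trains
text_embeddings.trainable_embedding.parameters() instead of the text encoder's full
input-embedding matrix, and the checkpointer translates a tokenizer id through
text_embeddings.id_mapping before reading a learned vector back out. The patch also drops an
unused train_loss accumulator and a per-step text_embeddings.save() call from the inner training
loop. A minimal sketch of such a wrapper follows, assuming only the two members this diff touches;
the class name, constructor, and implementation are hypothetical, not the author's code:

    import torch
    import torch.nn as nn

    class TrainableTextEmbeddings(nn.Module):
        """Hypothetical stand-in for the repository's text_embeddings object."""

        def __init__(self, text_encoder, placeholder_token_ids):
            super().__init__()
            base = text_encoder.get_input_embeddings()  # full vocabulary, stays frozen
            # Map each placeholder token's tokenizer id to a row in the small
            # trainable matrix; this is the lookup the checkpointer performs.
            self.id_mapping = {tid: i for i, tid in enumerate(placeholder_token_ids)}
            # Only these rows receive gradients; the rest of the vocabulary is untouched.
            self.trainable_embedding = nn.Embedding(
                len(placeholder_token_ids), base.embedding_dim
            )
            with torch.no_grad():
                for tid, i in self.id_mapping.items():
                    self.trainable_embedding.weight[i] = base.weight[tid]

Under that assumption, the patched call sites read:

    # optimizer = optimizer_class(text_embeddings.trainable_embedding.parameters(), ...)
    # learned_embeds = text_embeddings.trainable_embedding.weight[
    #     text_embeddings.id_mapping[placeholder_token_id]]

which keeps the optimizer and the checkpointer pointed at the same small trainable matrix,
rather than at two potentially diverging views of the encoder's embedding table.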