From cd544ace2bc204f7ba1c2f7ce3e1d2ed8bb3e114 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 23 Dec 2022 12:12:24 +0100 Subject: Fix --- train_ti.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index e933c48..9d06c50 100644 --- a/train_ti.py +++ b/train_ti.py @@ -405,10 +405,8 @@ class Checkpointer(CheckpointerBase): text_encoder = self.accelerator.unwrap_model(self.text_encoder) for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): - training_token_id = self.text_embeddings.id_mapping[placeholder_token_id] - # Save a checkpoint - learned_embeds = self.text_embeddings.trainable_embedding.weight[training_token_id] + learned_embeds = self.text_embeddings.trainable_embedding.weight[placeholder_token_id] learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) -- cgit v1.2.3-54-g00ecf