diff options
-rw-r--r-- | train_ti.py | 4 |
1 files changed, 1 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index e933c48..9d06c50 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -405,10 +405,8 @@ class Checkpointer(CheckpointerBase): | |||
405 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 405 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
406 | 406 | ||
407 | for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): | 407 | for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): |
408 | training_token_id = self.text_embeddings.id_mapping[placeholder_token_id] | ||
409 | |||
410 | # Save a checkpoint | 408 | # Save a checkpoint |
411 | learned_embeds = self.text_embeddings.trainable_embedding.weight[training_token_id] | 409 | learned_embeds = self.text_embeddings.trainable_embedding.weight[placeholder_token_id] |
412 | learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} | 410 | learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} |
413 | 411 | ||
414 | filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) | 412 | filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) |