diff options
| -rw-r--r-- | train_ti.py | 4 |
1 files changed, 1 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index e933c48..9d06c50 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -405,10 +405,8 @@ class Checkpointer(CheckpointerBase): | |||
| 405 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 405 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
| 406 | 406 | ||
| 407 | for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): | 407 | for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): |
| 408 | training_token_id = self.text_embeddings.id_mapping[placeholder_token_id] | ||
| 409 | |||
| 410 | # Save a checkpoint | 408 | # Save a checkpoint |
| 411 | learned_embeds = self.text_embeddings.trainable_embedding.weight[training_token_id] | 409 | learned_embeds = self.text_embeddings.trainable_embedding.weight[placeholder_token_id] |
| 412 | learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} | 410 | learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} |
| 413 | 411 | ||
| 414 | filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) | 412 | filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) |
