summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--train_ti.py4
1 files changed, 1 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py
index e933c48..9d06c50 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -405,10 +405,8 @@ class Checkpointer(CheckpointerBase):
405 text_encoder = self.accelerator.unwrap_model(self.text_encoder) 405 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
406 406
407 for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): 407 for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
408 training_token_id = self.text_embeddings.id_mapping[placeholder_token_id]
409
410 # Save a checkpoint 408 # Save a checkpoint
411 learned_embeds = self.text_embeddings.trainable_embedding.weight[training_token_id] 409 learned_embeds = self.text_embeddings.trainable_embedding.weight[placeholder_token_id]
412 learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} 410 learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
413 411
414 filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) 412 filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)