summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-23 12:12:24 +0100
committerVolpeon <git@volpeon.ink>2022-12-23 12:12:24 +0100
commitcd544ace2bc204f7ba1c2f7ce3e1d2ed8bb3e114 (patch)
treec3c1faded8d3a3c4171524c7492ea6e14dab6718 /train_ti.py
parentSimplified trainable embedding code again (diff)
downloadtextual-inversion-diff-cd544ace2bc204f7ba1c2f7ce3e1d2ed8bb3e114.tar.gz
textual-inversion-diff-cd544ace2bc204f7ba1c2f7ce3e1d2ed8bb3e114.tar.bz2
textual-inversion-diff-cd544ace2bc204f7ba1c2f7ce3e1d2ed8bb3e114.zip
Fix
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py4
1 files changed, 1 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py
index e933c48..9d06c50 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -405,10 +405,8 @@ class Checkpointer(CheckpointerBase):
405 text_encoder = self.accelerator.unwrap_model(self.text_encoder) 405 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
406 406
407 for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): 407 for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
408 training_token_id = self.text_embeddings.id_mapping[placeholder_token_id]
409
410 # Save a checkpoint 408 # Save a checkpoint
411 learned_embeds = self.text_embeddings.trainable_embedding.weight[training_token_id] 409 learned_embeds = self.text_embeddings.trainable_embedding.weight[placeholder_token_id]
412 learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} 410 learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
413 411
414 filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) 412 filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)