diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-23 12:12:24 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-23 12:12:24 +0100 |
| commit | cd544ace2bc204f7ba1c2f7ce3e1d2ed8bb3e114 (patch) | |
| tree | c3c1faded8d3a3c4171524c7492ea6e14dab6718 | |
| parent | Simplified trainable embedding code again (diff) | |
| download | textual-inversion-diff-cd544ace2bc204f7ba1c2f7ce3e1d2ed8bb3e114.tar.gz textual-inversion-diff-cd544ace2bc204f7ba1c2f7ce3e1d2ed8bb3e114.tar.bz2 textual-inversion-diff-cd544ace2bc204f7ba1c2f7ce3e1d2ed8bb3e114.zip | |
Fix
| -rw-r--r-- | train_ti.py | 4 |
1 files changed, 1 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index e933c48..9d06c50 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -405,10 +405,8 @@ class Checkpointer(CheckpointerBase): | |||
| 405 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 405 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
| 406 | 406 | ||
| 407 | for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): | 407 | for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): |
| 408 | training_token_id = self.text_embeddings.id_mapping[placeholder_token_id] | ||
| 409 | |||
| 410 | # Save a checkpoint | 408 | # Save a checkpoint |
| 411 | learned_embeds = self.text_embeddings.trainable_embedding.weight[training_token_id] | 409 | learned_embeds = self.text_embeddings.trainable_embedding.weight[placeholder_token_id] |
| 412 | learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} | 410 | learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} |
| 413 | 411 | ||
| 414 | filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) | 412 | filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) |
