diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-31 14:07:44 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-31 14:07:44 +0100 |
| commit | dc463a6b8ef120b7a0643569b66f9109ed38c652 (patch) | |
| tree | ae742a988b8541009a980c8b2f719696f9d7df27 /train_ti.py | |
| parent | Fixes (diff) | |
| download | textual-inversion-diff-dc463a6b8ef120b7a0643569b66f9109ed38c652.tar.gz textual-inversion-diff-dc463a6b8ef120b7a0643569b66f9109ed38c652.tar.bz2 textual-inversion-diff-dc463a6b8ef120b7a0643569b66f9109ed38c652.zip | |
Simplified multi-vector embedding code
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 5 |
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/train_ti.py b/train_ti.py index 3a5cfed..3776eb2 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -438,7 +438,7 @@ class Checkpointer(CheckpointerBase): | |||
| 438 | 438 | ||
| 439 | for new_token in self.new_tokens: | 439 | for new_token in self.new_tokens: |
| 440 | text_encoder.text_model.embeddings.save_embed( | 440 | text_encoder.text_model.embeddings.save_embed( |
| 441 | new_token.multi_ids, | 441 | new_token.ids, |
| 442 | checkpoints_path.joinpath(f"{slugify(new_token.token)}_{step}_{postfix}.bin") | 442 | checkpoints_path.joinpath(f"{slugify(new_token.token)}_{step}_{postfix}.bin") |
| 443 | ) | 443 | ) |
| 444 | 444 | ||
| @@ -537,8 +537,7 @@ def main(): | |||
| 537 | new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) | 537 | new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) |
| 538 | 538 | ||
| 539 | for (new_token, init_ids) in zip(new_tokens, initializer_token_ids): | 539 | for (new_token, init_ids) in zip(new_tokens, initializer_token_ids): |
| 540 | embeddings.add_embed(new_token.placeholder_id) | 540 | embeddings.add_embed(new_token.ids, init_ids) |
| 541 | embeddings.add_embed(new_token.multi_ids, init_ids) | ||
| 542 | 541 | ||
| 543 | print(f"Added {len(new_tokens)} new tokens.") | 542 | print(f"Added {len(new_tokens)} new tokens.") |
| 544 | 543 | ||
