author     Volpeon <git@volpeon.ink>    2022-12-31 14:07:44 +0100
committer  Volpeon <git@volpeon.ink>    2022-12-31 14:07:44 +0100
commit     dc463a6b8ef120b7a0643569b66f9109ed38c652
tree       ae742a988b8541009a980c8b2f719696f9d7df27
parent     Fixes
Simplified multi-vector embedding code
-rw-r--r--  common.py                  3
-rw-r--r--  models/clip/tokenizer.py  23
-rw-r--r--  train_ti.py                5
3 files changed, 14 insertions, 17 deletions
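In short, the placeholder token itself now serves as the first vector: add_multi_tokens returns a single ids list whose first entry doubles as meta_id, so callers make one add_embed call instead of two. A tiny, self-contained sketch of the resulting naming scheme (plain Python; the helper name is illustrative, not part of the repo):

def multi_token_names(token: str, num_vectors: int) -> list[str]:
    # "<tok>" expands to ["<tok>", "<tok>_1", ..., "<tok>_{n-1}"]; the bare token is vector 0.
    if num_vectors < 1:
        raise ValueError("Expected num_vectors to be >= 1")
    return [token] + [f"{token}_{i}" for i in range(1, num_vectors)]

print(multi_token_names("<tok>", 3))  # ['<tok>', '<tok>_1', '<tok>_2']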
diff --git a/common.py b/common.py
--- a/common.py
+++ b/common.py
@@ -32,7 +32,6 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC
         embed = file.get_tensor("embed")

         added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
-        embeddings.add_embed(added.placeholder_id)
-        embeddings.add_embed(added.multi_ids, embed)
+        embeddings.add_embed(added.ids, embed)

     return tokens
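For context, the loading path in common.py now registers each stored embedding with a single add_embed call. A minimal sketch of such a loader, not the repo's exact code: it assumes the .bin files are safetensors archives with an "embed" tensor (as the get_tensor call above suggests) and that tokenizer and embeddings expose the add_multi_tokens and add_embed(ids, tensor) interfaces from this commit.

from pathlib import Path

from safetensors import safe_open


def load_embeddings_from_dir_sketch(tokenizer, embeddings, embeddings_dir: Path) -> list[str]:
    tokens = []
    for filename in embeddings_dir.glob("*.bin"):
        # Each file stores one learned embedding of shape (num_vectors, embedding_dim).
        with safe_open(str(filename), framework="pt", device="cpu") as file:
            embed = file.get_tensor("embed")

        # The file stem is the placeholder token; one tokenizer id per vector.
        added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
        embeddings.add_embed(added.ids, embed)
        tokens.append(added.token)

    return tokens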
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 78871db..7e08287 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -8,8 +8,8 @@ from transformers import CLIPTokenizer

 class MultiCLIPTokenizerItem(NamedTuple):
     token: str
-    placeholder_id: int
-    multi_ids: list[int]
+    meta_id: int
+    ids: list[int]


 class MultiCLIPTokenizer(CLIPTokenizer):
@@ -30,20 +30,19 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         if isinstance(num_vectors, list):
             raise ValueError("Expected num_vectors to be int for single token")

-        super().add_tokens(new_tokens)
+        if num_vectors < 1:
+            raise ValueError("Expected num_vectors to be >= 1")

-        if num_vectors == 1:
-            multi_token = [new_tokens]
-        else:
-            multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)]
-        super().add_tokens(multi_token)
+        multi_token = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)]

-        meta_id = super().convert_tokens_to_ids(new_tokens)
-        multi_ids = super().convert_tokens_to_ids(multi_token)
+        super().add_tokens(multi_token)

-        self.token_map[meta_id] = multi_ids
+        ids = super().convert_tokens_to_ids(multi_token)
+        meta_id = ids[0]

-        return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids)
+        self.token_map[meta_id] = ids
+
+        return MultiCLIPTokenizerItem(new_tokens, meta_id, ids)

     def encode(self, *args, vector_shuffle=True, **kwargs):
         ids = super().encode(*args, **kwargs)
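Usage-wise, the returned item now satisfies item.meta_id == item.ids[0]. A short sketch under the assumption that MultiCLIPTokenizer inherits from_pretrained from CLIPTokenizer and that token_map is set up in its constructor; the checkpoint name is only an example:

from models.clip.tokenizer import MultiCLIPTokenizer

# Any CLIP tokenizer checkpoint works; this one is illustrative.
tokenizer = MultiCLIPTokenizer.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
)

# Registers "<token>", "<token>_1", "<token>_2", "<token>_3" as new tokens.
item = tokenizer.add_multi_tokens("<token>", 4)

assert item.meta_id == item.ids[0]
assert len(item.ids) == 4

# encode() can expand the meta id into all four placeholder positions via token_map.
print(tokenizer.token_map[item.meta_id])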
diff --git a/train_ti.py b/train_ti.py
index 3a5cfed..3776eb2 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -438,7 +438,7 @@ class Checkpointer(CheckpointerBase):

         for new_token in self.new_tokens:
             text_encoder.text_model.embeddings.save_embed(
-                new_token.multi_ids,
+                new_token.ids,
                 checkpoints_path.joinpath(f"{slugify(new_token.token)}_{step}_{postfix}.bin")
             )

@@ -537,8 +537,7 @@ def main():
     new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)

     for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
-        embeddings.add_embed(new_token.placeholder_id)
-        embeddings.add_embed(new_token.multi_ids, init_ids)
+        embeddings.add_embed(new_token.ids, init_ids)

     print(f"Added {len(new_tokens)} new tokens.")

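On the training side the same pattern appears once more: each placeholder's full ids list is initialized from its initializer token in a single call. Extracted into a standalone helper for illustration (the helper name and signature are mine, not the repo's; tokenizer and embeddings are the same kinds of objects used above):

def add_placeholder_tokens_sketch(tokenizer, embeddings, placeholder_tokens, initializer_token_ids, num_vectors):
    # One MultiCLIPTokenizerItem per placeholder; num_vectors may be an int or a per-token list.
    new_tokens = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)

    # Copy the initializer token's embedding(s) into every vector of the placeholder.
    for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
        embeddings.add_embed(new_token.ids, init_ids)

    return new_tokens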
