diff options
author | Volpeon <git@volpeon.ink> | 2022-12-31 14:07:44 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-31 14:07:44 +0100 |
commit | dc463a6b8ef120b7a0643569b66f9109ed38c652 (patch) | |
tree | ae742a988b8541009a980c8b2f719696f9d7df27 | |
parent | Fixes (diff) | |
download | textual-inversion-diff-dc463a6b8ef120b7a0643569b66f9109ed38c652.tar.gz textual-inversion-diff-dc463a6b8ef120b7a0643569b66f9109ed38c652.tar.bz2 textual-inversion-diff-dc463a6b8ef120b7a0643569b66f9109ed38c652.zip |
Simplified multi-vector embedding code
-rw-r--r-- | common.py | 3 | ||||
-rw-r--r-- | models/clip/tokenizer.py | 23 | ||||
-rw-r--r-- | train_ti.py | 5 |
3 files changed, 14 insertions, 17 deletions
@@ -32,7 +32,6 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC | |||
32 | embed = file.get_tensor("embed") | 32 | embed = file.get_tensor("embed") |
33 | 33 | ||
34 | added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0]) | 34 | added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0]) |
35 | embeddings.add_embed(added.placeholder_id) | 35 | embeddings.add_embed(added.ids, embed) |
36 | embeddings.add_embed(added.multi_ids, embed) | ||
37 | 36 | ||
38 | return tokens | 37 | return tokens |
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 78871db..7e08287 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
@@ -8,8 +8,8 @@ from transformers import CLIPTokenizer | |||
8 | 8 | ||
9 | class MultiCLIPTokenizerItem(NamedTuple): | 9 | class MultiCLIPTokenizerItem(NamedTuple): |
10 | token: str | 10 | token: str |
11 | placeholder_id: int | 11 | meta_id: int |
12 | multi_ids: list[int] | 12 | ids: list[int] |
13 | 13 | ||
14 | 14 | ||
15 | class MultiCLIPTokenizer(CLIPTokenizer): | 15 | class MultiCLIPTokenizer(CLIPTokenizer): |
@@ -30,20 +30,19 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
30 | if isinstance(num_vectors, list): | 30 | if isinstance(num_vectors, list): |
31 | raise ValueError("Expected num_vectors to be int for single token") | 31 | raise ValueError("Expected num_vectors to be int for single token") |
32 | 32 | ||
33 | super().add_tokens(new_tokens) | 33 | if num_vectors < 1: |
34 | raise ValueError("Expected num_vectors to be >= 1") | ||
34 | 35 | ||
35 | if num_vectors == 1: | 36 | multi_token = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)] |
36 | multi_token = [new_tokens] | ||
37 | else: | ||
38 | multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)] | ||
39 | super().add_tokens(multi_token) | ||
40 | 37 | ||
41 | meta_id = super().convert_tokens_to_ids(new_tokens) | 38 | super().add_tokens(multi_token) |
42 | multi_ids = super().convert_tokens_to_ids(multi_token) | ||
43 | 39 | ||
44 | self.token_map[meta_id] = multi_ids | 40 | ids = super().convert_tokens_to_ids(multi_token) |
41 | meta_id = ids[0] | ||
45 | 42 | ||
46 | return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids) | 43 | self.token_map[meta_id] = ids |
44 | |||
45 | return MultiCLIPTokenizerItem(new_tokens, meta_id, ids) | ||
47 | 46 | ||
48 | def encode(self, *args, vector_shuffle=True, **kwargs): | 47 | def encode(self, *args, vector_shuffle=True, **kwargs): |
49 | ids = super().encode(*args, **kwargs) | 48 | ids = super().encode(*args, **kwargs) |
diff --git a/train_ti.py b/train_ti.py index 3a5cfed..3776eb2 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -438,7 +438,7 @@ class Checkpointer(CheckpointerBase): | |||
438 | 438 | ||
439 | for new_token in self.new_tokens: | 439 | for new_token in self.new_tokens: |
440 | text_encoder.text_model.embeddings.save_embed( | 440 | text_encoder.text_model.embeddings.save_embed( |
441 | new_token.multi_ids, | 441 | new_token.ids, |
442 | checkpoints_path.joinpath(f"{slugify(new_token.token)}_{step}_{postfix}.bin") | 442 | checkpoints_path.joinpath(f"{slugify(new_token.token)}_{step}_{postfix}.bin") |
443 | ) | 443 | ) |
444 | 444 | ||
@@ -537,8 +537,7 @@ def main(): | |||
537 | new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) | 537 | new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) |
538 | 538 | ||
539 | for (new_token, init_ids) in zip(new_tokens, initializer_token_ids): | 539 | for (new_token, init_ids) in zip(new_tokens, initializer_token_ids): |
540 | embeddings.add_embed(new_token.placeholder_id) | 540 | embeddings.add_embed(new_token.ids, init_ids) |
541 | embeddings.add_embed(new_token.multi_ids, init_ids) | ||
542 | 541 | ||
543 | print(f"Added {len(new_tokens)} new tokens.") | 542 | print(f"Added {len(new_tokens)} new tokens.") |
544 | 543 | ||