author     Volpeon <git@volpeon.ink>  2022-12-31 14:07:44 +0100
committer  Volpeon <git@volpeon.ink>  2022-12-31 14:07:44 +0100
commit     dc463a6b8ef120b7a0643569b66f9109ed38c652 (patch)
tree       ae742a988b8541009a980c8b2f719696f9d7df27
parent     Fixes (diff)
download   textual-inversion-diff-dc463a6b8ef120b7a0643569b66f9109ed38c652.tar.gz
           textual-inversion-diff-dc463a6b8ef120b7a0643569b66f9109ed38c652.tar.bz2
           textual-inversion-diff-dc463a6b8ef120b7a0643569b66f9109ed38c652.zip
Simplified multi-vector embedding code
-rw-r--r--  common.py                   3
-rw-r--r--  models/clip/tokenizer.py   23
-rw-r--r--  train_ti.py                 5
3 files changed, 14 insertions, 17 deletions
diff --git a/common.py b/common.py
index e8d3ac1..1e7f4b9 100644
--- a/common.py
+++ b/common.py
@@ -32,7 +32,6 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC
         embed = file.get_tensor("embed")
 
         added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
-        embeddings.add_embed(added.placeholder_id)
-        embeddings.add_embed(added.multi_ids, embed)
+        embeddings.add_embed(added.ids, embed)
 
     return tokens
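Note: ManagedCLIPTextEmbeddings.add_embed itself is not part of this diff, so the following is only a rough, hypothetical sketch of the call shape used above (a list of placeholder ids paired with a tensor holding one embedding row per id). ToyEmbeddings and its fields are illustrative stand-ins, not the repo's class.

import torch


class ToyEmbeddings:
    # weight stands in for the text encoder's token embedding matrix
    def __init__(self, vocab_size: int, dim: int):
        self.weight = torch.zeros(vocab_size, dim)

    def add_embed(self, ids: list[int], embed: torch.Tensor) -> None:
        # one embedding row per placeholder id
        assert embed.shape == (len(ids), self.weight.shape[1])
        self.weight[torch.tensor(ids)] = embed


emb = ToyEmbeddings(vocab_size=100, dim=8)
emb.add_embed([97, 98, 99], torch.randn(3, 8))   # e.g. ids returned by add_multi_tokens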
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 78871db..7e08287 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -8,8 +8,8 @@ from transformers import CLIPTokenizer
 
 class MultiCLIPTokenizerItem(NamedTuple):
     token: str
-    placeholder_id: int
-    multi_ids: list[int]
+    meta_id: int
+    ids: list[int]
 
 
 class MultiCLIPTokenizer(CLIPTokenizer):
@@ -30,20 +30,19 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         if isinstance(num_vectors, list):
             raise ValueError("Expected num_vectors to be int for single token")
 
-        super().add_tokens(new_tokens)
+        if num_vectors < 1:
+            raise ValueError("Expected num_vectors to be >= 1")
 
-        if num_vectors == 1:
-            multi_token = [new_tokens]
-        else:
-            multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)]
-            super().add_tokens(multi_token)
+        multi_token = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)]
 
-        meta_id = super().convert_tokens_to_ids(new_tokens)
-        multi_ids = super().convert_tokens_to_ids(multi_token)
+        super().add_tokens(multi_token)
 
-        self.token_map[meta_id] = multi_ids
+        ids = super().convert_tokens_to_ids(multi_token)
+        meta_id = ids[0]
 
-        return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids)
+        self.token_map[meta_id] = ids
+
+        return MultiCLIPTokenizerItem(new_tokens, meta_id, ids)
 
     def encode(self, *args, vector_shuffle=True, **kwargs):
         ids = super().encode(*args, **kwargs)
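For illustration, a small self-contained sketch of what the simplified add_multi_tokens now does: a single add_tokens call over "<token>", "<token>_1", ..., "<token>_{n-1}", with the base token's id doubling as the meta id that keys token_map. ToyTokenizer and its _add_tokens/_convert helpers merely stand in for CLIPTokenizer.add_tokens and convert_tokens_to_ids.

from typing import NamedTuple


class MultiTokenItem(NamedTuple):
    token: str
    meta_id: int     # id of the base placeholder token
    ids: list[int]   # ids of all per-vector tokens, meta_id first


class ToyTokenizer:
    def __init__(self):
        self.vocab: dict[str, int] = {}
        self.token_map: dict[int, list[int]] = {}

    def _add_tokens(self, tokens: list[str]) -> None:
        # stand-in for CLIPTokenizer.add_tokens
        for t in tokens:
            self.vocab.setdefault(t, len(self.vocab))

    def _convert(self, tokens: list[str]) -> list[int]:
        # stand-in for CLIPTokenizer.convert_tokens_to_ids
        return [self.vocab[t] for t in tokens]

    def add_multi_tokens(self, new_token: str, num_vectors: int) -> MultiTokenItem:
        if num_vectors < 1:
            raise ValueError("Expected num_vectors to be >= 1")
        # "<token>", "<token>_1", ..., "<token>_{n-1}" in one add_tokens call
        multi_token = [new_token] + [f"{new_token}_{i}" for i in range(1, num_vectors)]
        self._add_tokens(multi_token)
        ids = self._convert(multi_token)
        meta_id = ids[0]              # base token id doubles as the meta id
        self.token_map[meta_id] = ids
        return MultiTokenItem(new_token, meta_id, ids)


tok = ToyTokenizer()
item = tok.add_multi_tokens("<cat>", 3)
print(item.meta_id, item.ids)        # 0 [0, 1, 2]
print(tok.token_map[item.meta_id])   # [0, 1, 2]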
diff --git a/train_ti.py b/train_ti.py
index 3a5cfed..3776eb2 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -438,7 +438,7 @@ class Checkpointer(CheckpointerBase):
 
         for new_token in self.new_tokens:
             text_encoder.text_model.embeddings.save_embed(
-                new_token.multi_ids,
+                new_token.ids,
                 checkpoints_path.joinpath(f"{slugify(new_token.token)}_{step}_{postfix}.bin")
             )
 
@@ -537,8 +537,7 @@ def main():
     new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
 
     for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
-        embeddings.add_embed(new_token.placeholder_id)
-        embeddings.add_embed(new_token.multi_ids, init_ids)
+        embeddings.add_embed(new_token.ids, init_ids)
 
     print(f"Added {len(new_tokens)} new tokens.")
 
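The body of encode is not part of this diff; presumably token_map is consumed there to expand each placeholder's meta id into its full list of per-vector ids (shuffled when vector_shuffle is set), roughly along these lines. expand_placeholder_ids is an illustrative name, not the repo's function.

import random


def expand_placeholder_ids(ids, token_map, vector_shuffle=True):
    # replace each meta id with its full list of per-vector ids
    out = []
    for i in ids:
        if i in token_map:
            sub = list(token_map[i])
            if vector_shuffle:
                random.shuffle(sub)
            out.extend(sub)
        else:
            out.append(i)
    return out


# with token_map = {42: [42, 43, 44]}, id 42 expands to all three vector ids
print(expand_placeholder_ids([5, 42, 9], {42: [42, 43, 44]}, vector_shuffle=False))
# -> [5, 42, 43, 44, 9]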