summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-31 14:07:44 +0100
committerVolpeon <git@volpeon.ink>2022-12-31 14:07:44 +0100
commitdc463a6b8ef120b7a0643569b66f9109ed38c652 (patch)
treeae742a988b8541009a980c8b2f719696f9d7df27 /models
parentFixes (diff)
downloadtextual-inversion-diff-dc463a6b8ef120b7a0643569b66f9109ed38c652.tar.gz
textual-inversion-diff-dc463a6b8ef120b7a0643569b66f9109ed38c652.tar.bz2
textual-inversion-diff-dc463a6b8ef120b7a0643569b66f9109ed38c652.zip
Simplified multi-vector embedding code
Diffstat (limited to 'models')
-rw-r--r--models/clip/tokenizer.py23
1 file changed, 11 insertions(+), 12 deletions(-)
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 78871db..7e08287 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -8,8 +8,8 @@ from transformers import CLIPTokenizer
8 8
9class MultiCLIPTokenizerItem(NamedTuple): 9class MultiCLIPTokenizerItem(NamedTuple):
10 token: str 10 token: str
11 placeholder_id: int 11 meta_id: int
12 multi_ids: list[int] 12 ids: list[int]
13 13
14 14
15class MultiCLIPTokenizer(CLIPTokenizer): 15class MultiCLIPTokenizer(CLIPTokenizer):
@@ -30,20 +30,19 @@ class MultiCLIPTokenizer(CLIPTokenizer):
30 if isinstance(num_vectors, list): 30 if isinstance(num_vectors, list):
31 raise ValueError("Expected num_vectors to be int for single token") 31 raise ValueError("Expected num_vectors to be int for single token")
32 32
33 super().add_tokens(new_tokens) 33 if num_vectors < 1:
34 raise ValueError("Expected num_vectors to be >= 1")
34 35
35 if num_vectors == 1: 36 multi_token = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)]
36 multi_token = [new_tokens]
37 else:
38 multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)]
39 super().add_tokens(multi_token)
40 37
41 meta_id = super().convert_tokens_to_ids(new_tokens) 38 super().add_tokens(multi_token)
42 multi_ids = super().convert_tokens_to_ids(multi_token)
43 39
44 self.token_map[meta_id] = multi_ids 40 ids = super().convert_tokens_to_ids(multi_token)
41 meta_id = ids[0]
45 42
46 return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids) 43 self.token_map[meta_id] = ids
44
45 return MultiCLIPTokenizerItem(new_tokens, meta_id, ids)
47 46
48 def encode(self, *args, vector_shuffle=True, **kwargs): 47 def encode(self, *args, vector_shuffle=True, **kwargs):
49 ids = super().encode(*args, **kwargs) 48 ids = super().encode(*args, **kwargs)