summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-01 11:36:00 +0100
committerVolpeon <git@volpeon.ink>2023-01-01 11:36:00 +0100
commitb7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5 (patch)
tree24fd6d9f3a92ce9f5cccd5cdd914edfee665af71 /models
parentFix (diff)
downloadtextual-inversion-diff-b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5.tar.gz
textual-inversion-diff-b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5.tar.bz2
textual-inversion-diff-b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5.zip
Fixed accuracy calc, other improvements
Diffstat (limited to 'models')
-rw-r--r--models/clip/tokenizer.py18
1 file changed, 11 insertions(+), 7 deletions(-)
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index a3e6e70..37d69a9 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -15,6 +15,10 @@ class MultiCLIPTokenizer(CLIPTokenizer):
15 def __init__(self, *args, **kwargs): 15 def __init__(self, *args, **kwargs):
16 super().__init__(*args, **kwargs) 16 super().__init__(*args, **kwargs)
17 self.token_map: dict[int, list[int]] = {} 17 self.token_map: dict[int, list[int]] = {}
18 self.vector_shuffle = False
19
20 def set_use_vector_shuffle(self, enable: bool):
21 self.vector_shuffle = enable
18 22
19 def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem: 23 def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem:
20 if isinstance(new_tokens, list): 24 if isinstance(new_tokens, list):
@@ -42,11 +46,11 @@ class MultiCLIPTokenizer(CLIPTokenizer):
42 46
43 return MultiCLIPTokenizerItem(new_tokens, ids) 47 return MultiCLIPTokenizerItem(new_tokens, ids)
44 48
45 def expand_id(self, id: int, vector_shuffle=True): 49 def expand_id(self, id: int):
46 if id in self.token_map: 50 if id in self.token_map:
47 tokens = self.token_map[id] 51 tokens = self.token_map[id]
48 52
49 if vector_shuffle and len(tokens) > 2: 53 if self.vector_shuffle and len(tokens) > 2:
50 subtokens = tokens[1:-1] 54 subtokens = tokens[1:-1]
51 np.random.shuffle(subtokens) 55 np.random.shuffle(subtokens)
52 tokens = tokens[:1] + subtokens + tokens[-1:] 56 tokens = tokens[:1] + subtokens + tokens[-1:]
@@ -55,21 +59,21 @@ class MultiCLIPTokenizer(CLIPTokenizer):
55 else: 59 else:
56 return [id] 60 return [id]
57 61
58 def expand_ids(self, ids: list[int], vector_shuffle=True): 62 def expand_ids(self, ids: list[int]):
59 return [ 63 return [
60 new_id 64 new_id
61 for id in ids 65 for id in ids
62 for new_id in self.expand_id(id, vector_shuffle) 66 for new_id in self.expand_id(id)
63 ] 67 ]
64 68
65 def _call_one(self, text, *args, vector_shuffle=True, **kwargs): 69 def _call_one(self, text, *args, **kwargs):
66 result = super()._call_one(text, *args, **kwargs) 70 result = super()._call_one(text, *args, **kwargs)
67 71
68 is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list) 72 is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list)
69 73
70 if is_batched: 74 if is_batched:
71 result.input_ids = [self.expand_ids(batch, vector_shuffle) for batch in result.input_ids] 75 result.input_ids = [self.expand_ids(batch) for batch in result.input_ids]
72 else: 76 else:
73 result.input_ids = self.expand_ids(result.input_ids, vector_shuffle) 77 result.input_ids = self.expand_ids(result.input_ids)
74 78
75 return result 79 return result