| author | Volpeon <git@volpeon.ink> | 2023-01-01 11:36:00 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-01 11:36:00 +0100 |
| commit | b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5 (patch) | |
| tree | 24fd6d9f3a92ce9f5cccd5cdd914edfee665af71 /models | |
| parent | Fix (diff) | |
Fixed accuracy calc, other improvements
Diffstat (limited to 'models')
| -rw-r--r-- | models/clip/tokenizer.py | 18 |
1 file changed, 11 insertions, 7 deletions
```diff
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index a3e6e70..37d69a9 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -15,6 +15,10 @@ class MultiCLIPTokenizer(CLIPTokenizer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.token_map: dict[int, list[int]] = {}
+        self.vector_shuffle = False
+
+    def set_use_vector_shuffle(self, enable: bool):
+        self.vector_shuffle = enable
 
     def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem:
         if isinstance(new_tokens, list):
@@ -42,11 +46,11 @@ class MultiCLIPTokenizer(CLIPTokenizer):
 
         return MultiCLIPTokenizerItem(new_tokens, ids)
 
-    def expand_id(self, id: int, vector_shuffle=True):
+    def expand_id(self, id: int):
         if id in self.token_map:
             tokens = self.token_map[id]
 
-            if vector_shuffle and len(tokens) > 2:
+            if self.vector_shuffle and len(tokens) > 2:
                 subtokens = tokens[1:-1]
                 np.random.shuffle(subtokens)
                 tokens = tokens[:1] + subtokens + tokens[-1:]
@@ -55,21 +59,21 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         else:
             return [id]
 
-    def expand_ids(self, ids: list[int], vector_shuffle=True):
+    def expand_ids(self, ids: list[int]):
         return [
             new_id
             for id in ids
-            for new_id in self.expand_id(id, vector_shuffle)
+            for new_id in self.expand_id(id)
         ]
 
-    def _call_one(self, text, *args, vector_shuffle=True, **kwargs):
+    def _call_one(self, text, *args, **kwargs):
         result = super()._call_one(text, *args, **kwargs)
 
         is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list)
 
         if is_batched:
-            result.input_ids = [self.expand_ids(batch, vector_shuffle) for batch in result.input_ids]
+            result.input_ids = [self.expand_ids(batch) for batch in result.input_ids]
         else:
-            result.input_ids = self.expand_ids(result.input_ids, vector_shuffle)
+            result.input_ids = self.expand_ids(result.input_ids)
 
         return result
```
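In short, `vector_shuffle` moves from a per-call keyword argument to tokenizer state toggled via `set_use_vector_shuffle()`. A minimal usage sketch of the new API follows; the `from_pretrained` checkpoint and the `<concept>` placeholder are assumptions for illustration, while `set_use_vector_shuffle`, `add_multi_tokens`, and the expansion inside `__call__` come from this commit.

```python
# Sketch only: shows the setter-based API from this commit.
# The checkpoint name and subfolder are assumptions for illustration.
from models.clip.tokenizer import MultiCLIPTokenizer

tokenizer = MultiCLIPTokenizer.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
)

# Register a placeholder token backed by 4 embedding vectors.
tokenizer.add_multi_tokens("<concept>", num_vectors=4)

# Shuffling is now toggled once on the instance instead of being threaded
# through every call as a vector_shuffle keyword argument.
tokenizer.set_use_vector_shuffle(True)   # e.g. while training
input_ids = tokenizer("a photo of <concept>").input_ids

tokenizer.set_use_vector_shuffle(False)  # e.g. for validation/inference
```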

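Note that the shuffle in `expand_id` permutes only the interior vector ids; the first and last stay in place. A self-contained sketch of that behaviour, where the token id 49408 and the vector ids are made-up example values:

```python
import numpy as np

# Assumed example mapping: placeholder token id 49408 expands to 4 vector ids.
token_map: dict[int, list[int]] = {49408: [100, 101, 102, 103]}

def expand_id(id: int, vector_shuffle: bool) -> list[int]:
    """Standalone version of MultiCLIPTokenizer.expand_id from the diff."""
    if id in token_map:
        tokens = token_map[id]
        if vector_shuffle and len(tokens) > 2:
            # Shuffle a copy of the interior ids; the first and last id
            # stay in place.
            subtokens = tokens[1:-1]
            np.random.shuffle(subtokens)
            tokens = tokens[:1] + subtokens + tokens[-1:]
        return tokens
    return [id]

print(expand_id(49408, vector_shuffle=True))   # e.g. [100, 102, 101, 103]
print(expand_id(49408, vector_shuffle=False))  # [100, 101, 102, 103]
print(expand_id(42, vector_shuffle=True))      # [42]: unmapped ids pass through
```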