summaryrefslogtreecommitdiffstats
path: root/models/clip
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-01 19:19:52 +0100
committerVolpeon <git@volpeon.ink>2023-01-01 19:19:52 +0100
commitadc52fb8821a496bc8d78235bf10466b39df03e0 (patch)
tree8a6337a6ac10cbe76c55514ab559c647e69fb1aa /models/clip
parentFixed accuracy calc, other improvements (diff)
downloadtextual-inversion-diff-adc52fb8821a496bc8d78235bf10466b39df03e0.tar.gz
textual-inversion-diff-adc52fb8821a496bc8d78235bf10466b39df03e0.tar.bz2
textual-inversion-diff-adc52fb8821a496bc8d78235bf10466b39df03e0.zip
Updates
Diffstat (limited to 'models/clip')
-rw-r--r--models/clip/embeddings.py11
-rw-r--r--models/clip/tokenizer.py76
2 files changed, 71 insertions, 16 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index f90e7c2..8602142 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -120,3 +120,14 @@ def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbe
120 text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) 120 text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
121 text_encoder.text_model.embeddings = text_embeddings 121 text_encoder.text_model.embeddings = text_embeddings
122 return text_embeddings 122 return text_embeddings
123
124
def unpatch_managed_embeddings(text_encoder: CLIPTextModel) -> CLIPTextEmbeddings:
    """Undo a previous patch_managed_embeddings call.

    Bakes any temporary vectors into the managed embeddings, then swaps a
    plain CLIPTextEmbeddings back into the text model, reusing the managed
    module's (now permanent) token and position embedding weights.

    Returns the restored CLIPTextEmbeddings instance.
    """
    managed = text_encoder.text_model.embeddings
    # Merge staged/temporary token vectors into the permanent weight table
    # before we discard the managed wrapper.
    managed.make_permanent()

    restored = CLIPTextEmbeddings(text_encoder.config)
    restored.token_embedding = managed.token_embedding
    restored.position_embedding = managed.position_embedding
    text_encoder.text_model.embeddings = restored

    return restored
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 37d69a9..ed9774e 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -1,11 +1,54 @@
1import copy 1import copy
2from typing import NamedTuple, Union 2from typing import NamedTuple, Union, Literal
3 3
4import numpy as np 4import numpy as np
5 5
6from transformers import CLIPTokenizer 6from transformers import CLIPTokenizer
7 7
8 8
def shuffle_all(tokens: list[int]):
    """Return a shuffled copy of *tokens*; the input list is not mutated.

    Lists shorter than two elements are returned unchanged (same object).
    """
    if len(tokens) < 2:
        return tokens
    shuffled = list(tokens)
    np.random.shuffle(shuffled)
    return shuffled
14
15
def shuffle_leading(tokens: list[int]):
    """Shuffle every element except the last, which stays anchored in place.

    Needs at least three elements to have any effect; shorter lists are
    returned unchanged (same object). The input list is not mutated.
    """
    if len(tokens) < 3:
        return tokens
    head = tokens[:-1]
    np.random.shuffle(head)
    return head + tokens[-1:]
22
23
def shuffle_trailing(tokens: list[int]):
    """Shuffle every element except the first, which stays anchored in place.

    Needs at least three elements to have any effect; shorter lists are
    returned unchanged (same object). The input list is not mutated.
    """
    if len(tokens) < 3:
        return tokens
    tail = tokens[1:]
    np.random.shuffle(tail)
    return tokens[:1] + tail
30
31
def shuffle_between(tokens: list[int]):
    """Shuffle the interior elements, keeping first and last anchored.

    Needs at least four elements to have any effect; shorter lists are
    returned unchanged (same object). The input list is not mutated.
    """
    if len(tokens) < 4:
        return tokens
    middle = tokens[1:-1]
    np.random.shuffle(middle)
    return tokens[:1] + middle + tokens[-1:]
38
39
def shuffle_none(tokens: list[int]):
    """Identity strategy: return *tokens* untouched (same object)."""
    return tokens
42
43
def shuffle_auto(tokens: list[int]):
    """Pick the strongest anchoring strategy the list length allows.

    Four or more elements: keep both ends fixed (shuffle_between).
    Exactly three: keep the first fixed (shuffle_trailing).
    Fewer: shuffle everything (shuffle_all; a no-op below two elements).
    """
    size = len(tokens)
    if size >= 4:
        return shuffle_between(tokens)
    return shuffle_trailing(tokens) if size >= 3 else shuffle_all(tokens)
50
51
class MultiCLIPTokenizerItem(NamedTuple):
    """Result of registering a (possibly multi-vector) token with the tokenizer."""
    # The placeholder token string as added to the vocabulary.
    token: str
    # Vocabulary ids backing this token, one per embedding vector.
    ids: list[int]
@@ -15,10 +58,21 @@ class MultiCLIPTokenizer(CLIPTokenizer):
def __init__(self, *args, **kwargs):
    """Initialize the tokenizer; expansion map starts empty, shuffling off."""
    super().__init__(*args, **kwargs)
    # Maps a placeholder token id to the list of ids it expands to.
    self.token_map: dict[int, list[int]] = {}
    # Shuffling strategy applied when expanding ids; identity by default.
    self.vector_shuffle = shuffle_none
19 62
def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "auto", "off"]]):
    """Select the shuffling strategy applied when expanding multi-vector tokens.

    Args:
        algorithm: One of "all", "leading", "trailing", "between", "auto",
            or "off". A bool is accepted for backward compatibility:
            True means "all", False means "off". Any unrecognized value
            disables shuffling.
    """
    # Fix: the annotation previously omitted "auto" even though the body
    # handled it; callers passing "auto" were valid all along.
    if algorithm is True:  # identity check instead of `== True` (E712)
        algorithm = "all"
    elif not isinstance(algorithm, str):
        # False (or any other non-string) disables shuffling.
        algorithm = "off"
    strategies = {
        "all": shuffle_all,
        "leading": shuffle_leading,
        "trailing": shuffle_trailing,
        "between": shuffle_between,
        "auto": shuffle_auto,
    }
    # "off" and unknown strings fall back to the identity strategy.
    self.vector_shuffle = strategies.get(algorithm, shuffle_none)
22 76
23 def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem: 77 def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem:
24 if isinstance(new_tokens, list): 78 if isinstance(new_tokens, list):
@@ -47,17 +101,7 @@ class MultiCLIPTokenizer(CLIPTokenizer):
47 return MultiCLIPTokenizerItem(new_tokens, ids) 101 return MultiCLIPTokenizerItem(new_tokens, ids)
48 102
def expand_id(self, id: int):
    """Expand a single token id into its list of vector ids.

    Ids registered in token_map are replaced by their mapped ids, run
    through the configured shuffling strategy; any other id maps to a
    one-element list containing itself.
    """
    try:
        expansion = self.token_map[id]
    except KeyError:
        # Not a managed multi-vector token: pass the id through unchanged.
        return [id]
    return self.vector_shuffle(expansion)
61 105
62 def expand_ids(self, ids: list[int]): 106 def expand_ids(self, ids: list[int]):
63 return [ 107 return [