summaryrefslogtreecommitdiffstats
path: root/models
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-05 10:19:38 +0100
committerVolpeon <git@volpeon.ink>2023-01-05 10:19:38 +0100
commit6c64f769043c8212b1a5778e857af691a828798d (patch)
treefe4cdf2a4e28e86e31bb7ccd8885c0a42c8632dc /models
parentUpdate (diff)
downloadtextual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.tar.gz
textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.tar.bz2
textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.zip
Various cleanups
Diffstat (limited to 'models')
-rw-r--r--models/clip/embeddings.py5
-rw-r--r--models/clip/tokenizer.py9
2 files changed, 7 insertions, 7 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 1280ebd..fb639f1 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -53,6 +53,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
53 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) 53 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
54 54
55 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): 55 def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
56 init_ratio = 1.0
57
56 if isinstance(token_ids, int): 58 if isinstance(token_ids, int):
57 token_ids = [token_ids] 59 token_ids = [token_ids]
58 60
@@ -63,6 +65,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
63 initializer = [initializer] 65 initializer = [initializer]
64 66
65 if isinstance(initializer, list): 67 if isinstance(initializer, list):
68 init_ratio = len(initializer) / len(token_ids)
66 initializer = (initializer * len(token_ids))[:len(token_ids)] 69 initializer = (initializer * len(token_ids))[:len(token_ids)]
67 70
68 with torch.no_grad(): 71 with torch.no_grad():
@@ -76,6 +79,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
76 dtype=self.temp_token_embedding.weight.dtype, 79 dtype=self.temp_token_embedding.weight.dtype,
77 ) 80 )
78 81
82 return init_ratio
83
79 def load_embed(self, input_ids: list[int], filename: Path): 84 def load_embed(self, input_ids: list[int], filename: Path):
80 with safe_open(filename, framework="pt", device="cpu") as file: 85 with safe_open(filename, framework="pt", device="cpu") as file:
81 self.add_embed(input_ids, file.get_tensor("embed")) 86 self.add_embed(input_ids, file.get_tensor("embed"))
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 4e97ab5..034adf9 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -55,11 +55,6 @@ def shuffle_auto(tokens: list[int]):
55 return shuffle_all(tokens) 55 return shuffle_all(tokens)
56 56
57 57
58class MultiCLIPTokenizerItem(NamedTuple):
59 token: str
60 ids: list[int]
61
62
63class MultiCLIPTokenizer(CLIPTokenizer): 58class MultiCLIPTokenizer(CLIPTokenizer):
64 def __init__(self, *args, **kwargs): 59 def __init__(self, *args, **kwargs):
65 super().__init__(*args, **kwargs) 60 super().__init__(*args, **kwargs)
@@ -96,7 +91,7 @@ class MultiCLIPTokenizer(CLIPTokenizer):
96 self, 91 self,
97 new_tokens: Union[str, list[str]], 92 new_tokens: Union[str, list[str]],
98 num_vectors: Union[int, list[int]] = 1 93 num_vectors: Union[int, list[int]] = 1
99 ) -> Union[MultiCLIPTokenizerItem, list[MultiCLIPTokenizerItem]]: 94 ) -> Union[list[int], list[list[int]]]:
100 if isinstance(new_tokens, list): 95 if isinstance(new_tokens, list):
101 if isinstance(num_vectors, int): 96 if isinstance(num_vectors, int):
102 num_vectors = [num_vectors] * len(new_tokens) 97 num_vectors = [num_vectors] * len(new_tokens)
@@ -119,7 +114,7 @@ class MultiCLIPTokenizer(CLIPTokenizer):
119 114
120 self.token_map[ids[0]] = ids 115 self.token_map[ids[0]] = ids
121 116
122 return MultiCLIPTokenizerItem(new_tokens, ids) 117 return ids
123 118
124 def expand_id(self, id: int): 119 def expand_id(self, id: int):
125 if id in self.token_map: 120 if id in self.token_map: