author      Volpeon <git@volpeon.ink>  2022-12-31 17:12:12 +0100
committer   Volpeon <git@volpeon.ink>  2022-12-31 17:12:12 +0100
commit      b42e7fbc29fd8045a2b932eb8ae76587f51f7513 (patch)
tree        85321e605cd8e183a0b9e05efcc4282921e667e0
parent      Simplified multi-vector embedding code (diff)
Bugfixes for multi-vector token handling
-rw-r--r--  common.py                  |  1
-rw-r--r--  infer.py                   | 13
-rw-r--r--  models/clip/embeddings.py  | 27
-rw-r--r--  models/clip/tokenizer.py   | 39
4 files changed, 53 insertions(+), 27 deletions(-)
diff --git a/common.py b/common.py
index 1e7f4b9..691be4e 100644
--- a/common.py
+++ b/common.py
@@ -30,7 +30,6 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC
         if filename.is_file():
             with safe_open(filename, framework="pt", device="cpu") as file:
                 embed = file.get_tensor("embed")
-
                 added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
                 embeddings.add_embed(added.ids, embed)
 
diff --git a/infer.py b/infer.py
index 4bcaff5..f88245a 100644
--- a/infer.py
+++ b/infer.py
@@ -7,6 +7,8 @@ import cmd
 from pathlib import Path
 import torch
 import json
+import traceback
+
 from PIL import Image
 from slugify import slugify
 from diffusers import (
@@ -165,8 +167,8 @@ def run_parser(parser, defaults, input=None):
     conf_args = argparse.Namespace()
 
     if args.config is not None:
-        args = load_config(args.config)
-        args = parser.parse_args(namespace=argparse.Namespace(**args))
+        conf_args = load_config(args.config)
+        conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[0]
 
     res = defaults.copy()
     for dict in [vars(conf_args), vars(args)]:
@@ -295,6 +297,7 @@ class CmdParse(cmd.Cmd):
             elements = shlex.split(line)
         except ValueError as e:
             print(str(e))
+            return
 
         if elements[0] == 'q':
             return True
@@ -306,9 +309,11 @@ class CmdParse(cmd.Cmd):
                 print('Try again with a prompt!')
                 return
         except SystemExit:
+            traceback.print_exc()
             self.parser.print_help()
+            return
         except Exception as e:
-            print(e)
+            traceback.print_exc()
             return
 
         try:
@@ -316,7 +321,7 @@ class CmdParse(cmd.Cmd):
         except KeyboardInterrupt:
             print('Generation cancelled.')
         except Exception as e:
-            print(e)
+            traceback.print_exc()
             return
 
     def do_exit(self, line):
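
Note on the run_parser hunk above: previously the loaded config was parsed back into args itself, so config values could mask flags given explicitly on the command line. The fix keeps the config in a separate conf_args namespace and merges it beneath the real CLI namespace. A minimal illustration of the intended layering (defaults, then config, then command line); the merge-loop body and the prompt/steps fields are assumptions for the example, not taken from infer.py:

    import argparse

    defaults = {"prompt": None, "steps": 30}

    # values parsed from the JSON file passed via --config (assumed example)
    conf_args = argparse.Namespace(steps=50)
    # values parsed from the actual command line; steps was not passed, so it is None
    args = argparse.Namespace(prompt="a photo of <token>", steps=None)

    res = defaults.copy()
    for ns in [vars(conf_args), vars(args)]:
        # later sources win, but None means "not provided" and must not overwrite
        res.update({k: v for k, v in ns.items() if v is not None})

    assert res == {"prompt": "a photo of <token>", "steps": 50}

The traceback.print_exc() changes in the same file only make parser and generation errors visible in the interactive prompt instead of being swallowed.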
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index f82873e..91a575d 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -15,8 +15,12 @@ from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 def expand_embedding(old_embedding: nn.Embedding, n: int) -> nn.Embedding:
     old_num_embeddings, old_embedding_dim = old_embedding.weight.size()
 
-    new_embedding = nn.Embedding(old_num_embeddings + n, old_embedding_dim)
-    new_embedding.to(old_embedding.weight.device, dtype=old_embedding.weight.dtype)
+    new_embedding = nn.Embedding(
+        old_num_embeddings + n,
+        old_embedding_dim,
+        device=old_embedding.weight.device,
+        dtype=old_embedding.weight.dtype
+    )
     new_embedding.weight.data.zero_()
     new_embedding.weight.data[:old_num_embeddings] = old_embedding.weight.data
 
@@ -31,9 +35,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.position_embedding = embeddings.position_embedding
 
         self.temp_token_embedding = nn.Embedding(
-            self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
+            self.token_embedding.num_embeddings,
+            self.token_embedding.embedding_dim,
+            device=self.token_embedding.weight.device,
+            dtype=self.token_embedding.weight.dtype
+        )
         self.temp_token_embedding.weight.data.zero_()
-        self.temp_token_ids = torch.tensor([])
+        self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
         if isinstance(token_ids, int):
@@ -52,12 +60,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.temp_token_embedding = expand_embedding(self.temp_token_embedding, len(token_ids))
         self.token_embedding = expand_embedding(self.token_embedding, len(token_ids))
 
-        token_ids = torch.tensor(token_ids)
+        token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
 
         if initializer is not None:
-            self.temp_token_embedding.weight.data[token_ids] = initializer
+            self.temp_token_embedding.weight.data[token_ids] = initializer.to(
+                dtype=self.temp_token_embedding.weight.dtype)
         else:
             self.temp_token_embedding.weight.data[token_ids].zero_()
 
@@ -70,13 +79,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
     def make_permanent(self):
         self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
-        self.temp_token_ids = torch.tensor([])
+        self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
-            input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device)
+            input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
-        mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device))
+        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
 
         embeds = self.token_embedding(input_ids)
         embeds[mask] = self.temp_token_embedding(input_ids)[mask]
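
The embeddings changes above are mostly about dtypes and devices: temp_token_ids is now always a long tensor, so it can be used both as an index into the weight tables and as the comparison set for torch.isin in get_embed, and new nn.Embedding tables are created directly on the right device/dtype. A rough sketch of the lifecycle these methods support; the helper name and the wiring of text_embeddings are assumptions, only add_multi_tokens, add_embed, get_embed and make_permanent come from the diff:

    def register_new_token(tokenizer, text_embeddings, name: str, n_vectors: int, init=None):
        # Sketch only: tokenizer is a MultiCLIPTokenizer, text_embeddings a
        # ManagedCLIPTextEmbeddings already patched into the text encoder.
        added = tokenizer.add_multi_tokens(name, n_vectors)  # MultiCLIPTokenizerItem with the new sub-token ids
        text_embeddings.add_embed(added.ids, init)           # rows land in temp_token_embedding
        return added.ids

    # During training, get_embed(input_ids) takes rows from token_embedding and
    # overwrites the positions matching temp_token_ids with rows from
    # temp_token_embedding (the table being optimised). After training,
    #     text_embeddings.make_permanent()
    # copies those rows into token_embedding and clears temp_token_ids.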
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 7e08287..63566e0 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -44,20 +44,33 @@ class MultiCLIPTokenizer(CLIPTokenizer):
 
         return MultiCLIPTokenizerItem(new_tokens, meta_id, ids)
 
-    def encode(self, *args, vector_shuffle=True, **kwargs):
-        ids = super().encode(*args, **kwargs)
-        new_ids = []
-
-        for id in ids:
-            if id in self.token_map:
-                tokens = self.token_map[id]
-
-                if vector_shuffle:
-                    tokens = copy.copy(tokens)
-                    np.random.shuffle(tokens)
-
-                new_ids = new_ids + self.token_map[id]
-            else:
-                new_ids.append(id)
-
-        return new_ids
+    def expand_id(self, id: int, vector_shuffle=True):
+        if id in self.token_map:
+            tokens = self.token_map[id]
+
+            if vector_shuffle:
+                tokens = copy.copy(tokens)
+                np.random.shuffle(tokens)
+
+            return tokens
+        else:
+            return [id]
+
+    def expand_ids(self, ids: list[int], vector_shuffle=True):
+        return [
+            new_id
+            for id in ids
+            for new_id in self.expand_id(id, vector_shuffle)
+        ]
+
+    def _call_one(self, text, *args, vector_shuffle=True, **kwargs):
+        result = super()._call_one(text, *args, **kwargs)
+
+        is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list)
+
+        if is_batched:
+            result.input_ids = [self.expand_ids(batch, vector_shuffle) for batch in result.input_ids]
+        else:
+            result.input_ids = self.expand_ids(result.input_ids, vector_shuffle)
+
+        return result
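
Moving the expansion from encode into _call_one means every entry point that funnels through _call_one (notably plain tokenizer(...) calls, single or batched) now returns multi-vector placeholders expanded to all of their sub-token ids, with vector_shuffle controlling their order. A small hedged usage sketch; the "<cat>" placeholder and its three vectors are invented for the example, and it assumes the vector_shuffle keyword is forwarded from __call__ down to _call_one:

    # Assumes "<cat>" was registered earlier via tokenizer.add_multi_tokens("<cat>", 3),
    # so its meta token id maps to three sub-token ids in tokenizer.token_map.
    ids = tokenizer("a photo of <cat>", vector_shuffle=False).input_ids
    # -> [..., cat_id_0, cat_id_1, cat_id_2, ...] rather than a single placeholder id

    # A list of prompts hits the is_batched branch and expands each entry separately:
    batch = tokenizer(["a photo of <cat>", "no placeholder here"]).input_ids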