commit    b42e7fbc29fd8045a2b932eb8ae76587f51f7513
tree      85321e605cd8e183a0b9e05efcc4282921e667e0
parent    Simplified multi-vector embedding code
author    Volpeon <git@volpeon.ink>  2022-12-31 17:12:12 +0100
committer Volpeon <git@volpeon.ink>  2022-12-31 17:12:12 +0100
Bugfixes for multi-vector token handling
-rw-r--r--  common.py                  |  1
-rw-r--r--  infer.py                   | 13
-rw-r--r--  models/clip/embeddings.py  | 27
-rw-r--r--  models/clip/tokenizer.py   | 39

4 files changed, 53 insertions, 27 deletions
diff --git a/common.py b/common.py
--- a/common.py
+++ b/common.py
@@ -30,7 +30,6 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC
         if filename.is_file():
             with safe_open(filename, framework="pt", device="cpu") as file:
                 embed = file.get_tensor("embed")
-
                 added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
                 embeddings.add_embed(added.ids, embed)
 
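Note: the hunk above only drops a stray blank line, but it shows the registration flow the rest of this commit fixes: each safetensors file in the embeddings directory becomes one multi-vector token named after the file stem. A minimal usage sketch, assuming a MultiCLIPTokenizer and a ManagedCLIPTextEmbeddings are already constructed and that the final argument is the embeddings directory as a Path (the directory name here is hypothetical):

    # Hypothetical usage; load_embeddings_from_dir, add_multi_tokens and
    # add_embed are the repository functions shown in the hunk above.
    from pathlib import Path

    load_embeddings_from_dir(tokenizer, embeddings, Path("embeddings_ti"))
    # For a file "skull.safetensors" whose "embed" tensor has shape [3, dim],
    # this registers a 3-vector token "skull" and copies the tensor into the
    # managed embedding table via embeddings.add_embed(added.ids, embed).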
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -7,6 +7,8 @@ import cmd
 from pathlib import Path
 import torch
 import json
+import traceback
+
 from PIL import Image
 from slugify import slugify
 from diffusers import (
@@ -165,8 +167,8 @@ def run_parser(parser, defaults, input=None):
     conf_args = argparse.Namespace()
 
     if args.config is not None:
-        args = load_config(args.config)
-        args = parser.parse_args(namespace=argparse.Namespace(**args))
+        conf_args = load_config(args.config)
+        conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[0]
 
     res = defaults.copy()
     for dict in [vars(conf_args), vars(args)]:
@@ -295,6 +297,7 @@ class CmdParse(cmd.Cmd):
             elements = shlex.split(line)
         except ValueError as e:
             print(str(e))
+            return
 
         if elements[0] == 'q':
             return True
@@ -306,9 +309,11 @@
             print('Try again with a prompt!')
             return
         except SystemExit:
+            traceback.print_exc()
             self.parser.print_help()
+            return
         except Exception as e:
-            print(e)
+            traceback.print_exc()
             return
 
         try:
@@ -316,7 +321,7 @@
         except KeyboardInterrupt:
             print('Generation cancelled.')
         except Exception as e:
-            print(e)
+            traceback.print_exc()
         return
 
     def do_exit(self, line):
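Note: run_parser now keeps the values loaded from a config file in their own conf_args namespace and validates them with parse_known_args, so the final merge loop can layer parser defaults, config values, and explicitly passed arguments in that order. A self-contained sketch of this merge pattern (flag names hypothetical, config dict inlined instead of loaded from a file):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--steps', type=int)
    parser.add_argument('--prompt')

    # values as they might come out of a config loader (hypothetical)
    conf = {'steps': 20, 'prompt': 'a photo of a skull'}

    # run the config through the parser; parse_known_args tolerates keys the
    # parser does not know (an empty argv keeps this sketch deterministic)
    conf_args = parser.parse_known_args([], namespace=argparse.Namespace(**conf))[0]

    # values from the actual command line
    args = parser.parse_args(['--steps', '30'])

    # merge: defaults < config < command line, skipping unset (None) values
    res = {}
    for ns in [vars(conf_args), vars(args)]:
        res.update({k: v for k, v in ns.items() if v is not None})

    print(res)  # {'steps': 30, 'prompt': 'a photo of a skull'}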
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index f82873e..91a575d 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -15,8 +15,12 @@ from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 def expand_embedding(old_embedding: nn.Embedding, n: int) -> nn.Embedding:
     old_num_embeddings, old_embedding_dim = old_embedding.weight.size()
 
-    new_embedding = nn.Embedding(old_num_embeddings + n, old_embedding_dim)
-    new_embedding.to(old_embedding.weight.device, dtype=old_embedding.weight.dtype)
+    new_embedding = nn.Embedding(
+        old_num_embeddings + n,
+        old_embedding_dim,
+        device=old_embedding.weight.device,
+        dtype=old_embedding.weight.dtype
+    )
     new_embedding.weight.data.zero_()
     new_embedding.weight.data[:old_num_embeddings] = old_embedding.weight.data
 
@@ -31,9 +35,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.position_embedding = embeddings.position_embedding
 
         self.temp_token_embedding = nn.Embedding(
-            self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
+            self.token_embedding.num_embeddings,
+            self.token_embedding.embedding_dim,
+            device=self.token_embedding.weight.device,
+            dtype=self.token_embedding.weight.dtype
+        )
         self.temp_token_embedding.weight.data.zero_()
-        self.temp_token_ids = torch.tensor([])
+        self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
         if isinstance(token_ids, int):
@@ -52,12 +60,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.temp_token_embedding = expand_embedding(self.temp_token_embedding, len(token_ids))
         self.token_embedding = expand_embedding(self.token_embedding, len(token_ids))
 
-        token_ids = torch.tensor(token_ids)
+        token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
 
         if initializer is not None:
-            self.temp_token_embedding.weight.data[token_ids] = initializer
+            self.temp_token_embedding.weight.data[token_ids] = initializer.to(
+                dtype=self.temp_token_embedding.weight.dtype)
         else:
             self.temp_token_embedding.weight.data[token_ids].zero_()
 
@@ -70,13 +79,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
     def make_permanent(self):
         self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
-        self.temp_token_ids = torch.tensor([])
+        self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
-            input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device)
+            input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
-        mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device))
+        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
 
         embeds = self.token_embedding(input_ids)
         embeds[mask] = self.temp_token_embedding(input_ids)[mask]
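Note: several of these fixes are dtype corrections. torch.tensor([]) defaults to float32, and floating-point ids cannot be used for embedding lookups or torch.isin membership tests; likewise, index-assigning an initializer of a different precision into the embedding table fails. A minimal standalone sketch of the temporary-embedding overlay that get_embed implements (table sizes hypothetical):

    import torch
    import torch.nn as nn

    token_embedding = nn.Embedding(10, 4)       # base table
    temp_token_embedding = nn.Embedding(10, 4)  # trainable overlay
    temp_token_embedding.weight.data.zero_()

    # ids must be integer tensors; torch.tensor([]) would be float32
    temp_token_ids = torch.tensor([7, 8], dtype=torch.long)
    temp_token_embedding.weight.data[temp_token_ids] = 1.0  # stand-in initializer

    input_ids = torch.tensor([1, 7, 8, 2], dtype=torch.long)

    # overlay: temporary vectors win wherever the id is registered
    mask = torch.isin(input_ids, temp_token_ids)
    embeds = token_embedding(input_ids)
    embeds[mask] = temp_token_embedding(input_ids)[mask]
    # rows 1 and 2 of embeds now come from temp_token_embedding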
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 7e08287..63566e0 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -44,20 +44,33 @@ class MultiCLIPTokenizer(CLIPTokenizer):
 
         return MultiCLIPTokenizerItem(new_tokens, meta_id, ids)
 
-    def encode(self, *args, vector_shuffle=True, **kwargs):
-        ids = super().encode(*args, **kwargs)
-        new_ids = []
-
-        for id in ids:
-            if id in self.token_map:
-                tokens = self.token_map[id]
-
-                if vector_shuffle:
-                    tokens = copy.copy(tokens)
-                    np.random.shuffle(tokens)
-
-                new_ids = new_ids + self.token_map[id]
-            else:
-                new_ids.append(id)
-
-        return new_ids
+    def expand_id(self, id: int, vector_shuffle=True):
+        if id in self.token_map:
+            tokens = self.token_map[id]
+
+            if vector_shuffle:
+                tokens = copy.copy(tokens)
+                np.random.shuffle(tokens)
+
+            return tokens
+        else:
+            return [id]
+
+    def expand_ids(self, ids: list[int], vector_shuffle=True):
+        return [
+            new_id
+            for id in ids
+            for new_id in self.expand_id(id, vector_shuffle)
+        ]
+
+    def _call_one(self, text, *args, vector_shuffle=True, **kwargs):
+        result = super()._call_one(text, *args, **kwargs)
+
+        is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list)
+
+        if is_batched:
+            result.input_ids = [self.expand_ids(batch, vector_shuffle) for batch in result.input_ids]
+        else:
+            result.input_ids = self.expand_ids(result.input_ids, vector_shuffle)
+
+        return result
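Note: besides moving the expansion from encode() into _call_one() — so that direct tokenizer(...) calls, single or batched, also expand multi-vector tokens — the rewrite fixes a shuffle bug: the old loop shuffled a copy of the mapped tokens but then appended the unshuffled self.token_map[id]. The new expand_id returns the shuffled copy. The expansion logic in isolation (token_map contents hypothetical):

    import copy
    import numpy as np

    token_map = {100: [100, 101, 102]}  # meta id -> its per-vector ids

    def expand_id(id, vector_shuffle=True):
        if id in token_map:
            tokens = token_map[id]
            if vector_shuffle:
                tokens = copy.copy(tokens)
                np.random.shuffle(tokens)
            return tokens  # the fix: return the (shuffled) copy, not token_map[id]
        return [id]

    def expand_ids(ids, vector_shuffle=True):
        return [new_id for id in ids for new_id in expand_id(id, vector_shuffle)]

    print(expand_ids([1, 100, 2], vector_shuffle=False))
    # [1, 100, 101, 102, 2]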
