From b42e7fbc29fd8045a2b932eb8ae76587f51f7513 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 31 Dec 2022 17:12:12 +0100
Subject: Bugfixes for multi-vector token handling

---
 common.py                 |  1 -
 infer.py                  | 13 +++++++++----
 models/clip/embeddings.py | 27 ++++++++++++++++++---------
 models/clip/tokenizer.py  | 39 ++++++++++++++++++++++++++-------------
 4 files changed, 53 insertions(+), 27 deletions(-)

diff --git a/common.py b/common.py
index 1e7f4b9..691be4e 100644
--- a/common.py
+++ b/common.py
@@ -30,7 +30,6 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC
         if filename.is_file():
             with safe_open(filename, framework="pt", device="cpu") as file:
                 embed = file.get_tensor("embed")
-
                 added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
                 embeddings.add_embed(added.ids, embed)
 
diff --git a/infer.py b/infer.py
index 4bcaff5..f88245a 100644
--- a/infer.py
+++ b/infer.py
@@ -7,6 +7,8 @@ import cmd
 from pathlib import Path
 import torch
 import json
+import traceback
+
 from PIL import Image
 from slugify import slugify
 from diffusers import (
@@ -165,8 +167,8 @@ def run_parser(parser, defaults, input=None):
     conf_args = argparse.Namespace()
 
     if args.config is not None:
-        args = load_config(args.config)
-        args = parser.parse_args(namespace=argparse.Namespace(**args))
+        conf_args = load_config(args.config)
+        conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[0]
 
     res = defaults.copy()
     for dict in [vars(conf_args), vars(args)]:
@@ -295,6 +297,7 @@ class CmdParse(cmd.Cmd):
             elements = shlex.split(line)
         except ValueError as e:
             print(str(e))
+            return
 
         if elements[0] == 'q':
             return True
@@ -306,9 +309,11 @@ class CmdParse(cmd.Cmd):
                 print('Try again with a prompt!')
                 return
         except SystemExit:
+            traceback.print_exc()
             self.parser.print_help()
+            return
         except Exception as e:
-            print(e)
+            traceback.print_exc()
             return
 
         try:
@@ -316,7 +321,7 @@ class CmdParse(cmd.Cmd):
         except KeyboardInterrupt:
             print('Generation cancelled.')
         except Exception as e:
-            print(e)
+            traceback.print_exc()
             return
 
     def do_exit(self, line):
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index f82873e..91a575d 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -15,8 +15,12 @@ from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 def expand_embedding(old_embedding: nn.Embedding, n: int) -> nn.Embedding:
     old_num_embeddings, old_embedding_dim = old_embedding.weight.size()
 
-    new_embedding = nn.Embedding(old_num_embeddings + n, old_embedding_dim)
-    new_embedding.to(old_embedding.weight.device, dtype=old_embedding.weight.dtype)
+    new_embedding = nn.Embedding(
+        old_num_embeddings + n,
+        old_embedding_dim,
+        device=old_embedding.weight.device,
+        dtype=old_embedding.weight.dtype
+    )
     new_embedding.weight.data.zero_()
     new_embedding.weight.data[:old_num_embeddings] = old_embedding.weight.data
 
@@ -31,9 +35,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.position_embedding = embeddings.position_embedding
 
         self.temp_token_embedding = nn.Embedding(
-            self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
+            self.token_embedding.num_embeddings,
+            self.token_embedding.embedding_dim,
+            device=self.token_embedding.weight.device,
+            dtype=self.token_embedding.weight.dtype
+        )
         self.temp_token_embedding.weight.data.zero_()
-        self.temp_token_ids = torch.tensor([])
+        self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
         if isinstance(token_ids, int):
@@ -52,12 +60,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.temp_token_embedding = expand_embedding(self.temp_token_embedding, len(token_ids))
         self.token_embedding = expand_embedding(self.token_embedding, len(token_ids))
 
-        token_ids = torch.tensor(token_ids)
+        token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
 
         if initializer is not None:
-            self.temp_token_embedding.weight.data[token_ids] = initializer
+            self.temp_token_embedding.weight.data[token_ids] = initializer.to(
+                dtype=self.temp_token_embedding.weight.dtype)
         else:
             self.temp_token_embedding.weight.data[token_ids].zero_()
 
@@ -70,13 +79,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
     def make_permanent(self):
         self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
-        self.temp_token_ids = torch.tensor([])
+        self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
-            input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device)
+            input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
-        mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device))
+        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
 
         embeds = self.token_embedding(input_ids)
         embeds[mask] = self.temp_token_embedding(input_ids)[mask]
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 7e08287..63566e0 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -44,20 +44,33 @@ class MultiCLIPTokenizer(CLIPTokenizer):
 
         return MultiCLIPTokenizerItem(new_tokens, meta_id, ids)
 
-    def encode(self, *args, vector_shuffle=True, **kwargs):
-        ids = super().encode(*args, **kwargs)
-        new_ids = []
+    def expand_id(self, id: int, vector_shuffle=True):
+        if id in self.token_map:
+            tokens = self.token_map[id]
 
-        for id in ids:
-            if id in self.token_map:
-                tokens = self.token_map[id]
+            if vector_shuffle:
+                tokens = copy.copy(tokens)
+                np.random.shuffle(tokens)
 
-                if vector_shuffle:
-                    tokens = copy.copy(tokens)
-                    np.random.shuffle(tokens)
+            return tokens
+        else:
+            return [id]
 
-                new_ids = new_ids + self.token_map[id]
-            else:
-                new_ids.append(id)
+    def expand_ids(self, ids: list[int], vector_shuffle=True):
+        return [
+            new_id
+            for id in ids
+            for new_id in self.expand_id(id, vector_shuffle)
+        ]
 
-        return new_ids
+    def _call_one(self, text, *args, vector_shuffle=True, **kwargs):
+        result = super()._call_one(text, *args, **kwargs)
+
+        is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list)
+
+        if is_batched:
+            result.input_ids = [self.expand_ids(batch, vector_shuffle) for batch in result.input_ids]
+        else:
+            result.input_ids = self.expand_ids(result.input_ids, vector_shuffle)
+
+        return result
-- 
cgit v1.2.3-70-g09d2
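
For reference, here is a minimal standalone sketch of the expansion logic that this patch moves out of encode() into expand_id()/expand_ids(), and of the bug it fixes (the old loop shuffled a copy of the sub-token list but then appended the unshuffled self.token_map[id]). The token_map contents and token ids below are invented for illustration; the snippet does not import the repo's MultiCLIPTokenizer.

    import copy
    import numpy as np

    # Hypothetical mapping from a placeholder token id to its multi-vector sub-token ids.
    token_map = {49408: [49408, 49409, 49410]}

    def expand_id(id: int, vector_shuffle: bool = True) -> list[int]:
        # Replace a multi-vector token id with all of its sub-token ids,
        # shuffling a copy when requested and returning that shuffled copy.
        if id in token_map:
            tokens = token_map[id]
            if vector_shuffle:
                tokens = copy.copy(tokens)
                np.random.shuffle(tokens)
            return tokens
        return [id]

    def expand_ids(ids: list[int], vector_shuffle: bool = True) -> list[int]:
        # Flatten the per-id expansions into one id sequence.
        return [new_id for id in ids for new_id in expand_id(id, vector_shuffle)]

    print(expand_ids([320, 49408, 1125]))
    # e.g. [320, 49410, 49408, 49409, 1125] -- the order of the three sub-token ids varies per call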