From 1c63552a20f34bccd461ac0dfa46405f853cbc7c Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 27 Mar 2023 11:58:47 +0200 Subject: Fix TI --- models/clip/embeddings.py | 34 +++++++++------------------------- train_ti.py | 11 ++++++++++- 2 files changed, 19 insertions(+), 26 deletions(-) diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 2b315c4..2d60c28 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -38,24 +38,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): self.token_embedding = embeddings.token_embedding self.position_embedding = embeddings.position_embedding self.initializer_factor = config.initializer_factor - self.num_permanent_embeddings = self.token_embedding.num_embeddings - self.init_temp_embeddings() - def init_temp_embeddings(self): self.temp_token_embedding = nn.Embedding( - 0, + self.token_embedding.num_embeddings, self.token_embedding.embedding_dim, device=self.token_embedding.weight.device, dtype=self.token_embedding.weight.dtype ) + self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach() self.temp_token_ids = torch.tensor([], dtype=torch.long) def resize(self, size: int): - self.temp_token_embedding = resize_embedding( - self.temp_token_embedding, - size - self.num_permanent_embeddings, - self.initializer_factor - ) + self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): @@ -75,15 +69,14 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): initializer = self.get_embed(initializer) initializer = initializer.to( - device=self.token_embedding.weight.device, - dtype=self.token_embedding.weight.dtype, + device=self.temp_token_embedding.weight.device, + dtype=self.temp_token_embedding.weight.dtype, ) token_ids = torch.tensor(token_ids, dtype=torch.long) self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) - mask = torch.nonzero(self.temp_token_ids == token_ids).squeeze(1) - self.temp_token_embedding.weight.data[mask] = initializer + self.temp_token_embedding.weight.data[token_ids] = initializer def load_embed(self, input_ids: list[int], filename: Path): with safe_open(filename, framework="pt", device="cpu") as file: @@ -94,25 +87,16 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): def persist(self): self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] - self.num_permanent_embeddings = self.token_embedding.num_embeddings - self.init_temp_embeddings() + self.temp_token_ids = torch.tensor([], dtype=torch.long) def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): if isinstance(input_ids, list): input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) - all_temp_token_ids = self.temp_token_ids.to(input_ids.device) - embeds = self.token_embedding(input_ids) - embeds_mask = torch.isin(input_ids, all_temp_token_ids) - temp_token_ids = input_ids[embeds_mask] - - temp_token_ids = temp_token_ids.unsqueeze(1) - all_temp_token_ids = all_temp_token_ids.unsqueeze(0) - temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() - - embeds[embeds_mask] = self.temp_token_embedding(temp_token_ids) + mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) + embeds[mask] = self.temp_token_embedding(input_ids)[mask] return embeds diff --git a/train_ti.py b/train_ti.py index ef39c38..9ae8d1b 100644 --- a/train_ti.py +++ b/train_ti.py @@ -155,7 +155,7 @@ def parse_args(): parser.add_argument( "--num_buckets", type=int, - default=2, + default=0, help="Number of aspect ratio buckets in either direction.", ) parser.add_argument( @@ -507,9 +507,18 @@ def parse_args(): if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") + if args.alias_tokens is None: + args.alias_tokens = [] + if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: raise ValueError("--alias_tokens must be a list with an even number of items") + args.alias_tokens += [ + item + for pair in zip(args.placeholder_tokens, args.initializer_tokens) + for item in pair + ] + if args.sequential: if isinstance(args.train_data_template, str): args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) -- cgit v1.2.3-54-g00ecf