From 56edf85c8b80d49c998bcf26392cce50d552137a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 31 Dec 2022 23:09:41 +0100 Subject: Update --- common.py | 24 ++++++++++++++++-------- models/clip/embeddings.py | 30 ++++++++++++++++-------------- models/clip/tokenizer.py | 6 ++---- train_ti.py | 1 + 4 files changed, 35 insertions(+), 26 deletions(-) diff --git a/common.py b/common.py index 691be4e..0887197 100644 --- a/common.py +++ b/common.py @@ -24,13 +24,21 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC return [] filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()] - tokens = [filename.stem for filename in filenames] - for filename in embeddings_dir.iterdir(): - if filename.is_file(): - with safe_open(filename, framework="pt", device="cpu") as file: - embed = file.get_tensor("embed") - added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0]) - embeddings.add_embed(added.ids, embed) + new_tokens = [] + new_embeds = [] - return tokens + for filename in filenames: + with safe_open(filename, framework="pt", device="cpu") as file: + embed = file.get_tensor("embed") + + added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0]) + new_tokens.append(added) + new_embeds.append(embed) + + embeddings.resize(len(tokenizer)) + + for (new_token, embeds) in zip(new_tokens, new_embeds): + embeddings.add_embed(new_token.ids, embeds) + + return new_tokens diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 91a575d..cab1515 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -12,18 +12,22 @@ from transformers.models.clip import CLIPTextConfig from transformers.models.clip.modeling_clip import CLIPTextEmbeddings -def expand_embedding(old_embedding: nn.Embedding, n: int) -> nn.Embedding: +def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: float = 1.0) -> nn.Embedding: old_num_embeddings, old_embedding_dim = old_embedding.weight.size() + if old_num_embeddings == new_num_embeddings: + return old_embedding + + n = min(old_num_embeddings, new_num_embeddings) + new_embedding = nn.Embedding( - old_num_embeddings + n, + new_num_embeddings, old_embedding_dim, device=old_embedding.weight.device, dtype=old_embedding.weight.dtype ) - new_embedding.weight.data.zero_() - new_embedding.weight.data[:old_num_embeddings] = old_embedding.weight.data - + new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) + new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :] return new_embedding @@ -40,9 +44,13 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): device=self.token_embedding.weight.device, dtype=self.token_embedding.weight.dtype ) - self.temp_token_embedding.weight.data.zero_() + self.temp_token_embedding.weight.data.normal_(mean=0.0, std=config.initializer_factor * 0.02) self.temp_token_ids = torch.tensor([], dtype=torch.long) + def resize(self, size: int): + self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.config.initializer_factor) + self.token_embedding = resize_embedding(self.token_embedding, size, self.config.initializer_factor) + def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): if isinstance(token_ids, int): token_ids = [token_ids] @@ -55,20 +63,14 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): initializer = (initializer * len(token_ids))[:len(token_ids)] with torch.no_grad(): - initializer = self.get_embed(initializer) - - self.temp_token_embedding = expand_embedding(self.temp_token_embedding, len(token_ids)) - self.token_embedding = expand_embedding(self.token_embedding, len(token_ids)) + initializer = self.get_embed(initializer).to(dtype=self.temp_token_embedding.weight.dtype) token_ids = torch.tensor(token_ids, dtype=torch.long) self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) if initializer is not None: - self.temp_token_embedding.weight.data[token_ids] = initializer.to( - dtype=self.temp_token_embedding.weight.dtype) - else: - self.temp_token_embedding.weight.data[token_ids].zero_() + self.temp_token_embedding.weight.data[token_ids] = initializer def load_embed(self, input_ids: list[int], filename: Path): with safe_open(filename, framework="pt", device="cpu") as file: diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 63566e0..fbfe790 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py @@ -8,7 +8,6 @@ from transformers import CLIPTokenizer class MultiCLIPTokenizerItem(NamedTuple): token: str - meta_id: int ids: list[int] @@ -38,11 +37,10 @@ class MultiCLIPTokenizer(CLIPTokenizer): super().add_tokens(multi_token) ids = super().convert_tokens_to_ids(multi_token) - meta_id = ids[0] - self.token_map[meta_id] = ids + self.token_map[ids[0]] = ids - return MultiCLIPTokenizerItem(new_tokens, meta_id, ids) + return MultiCLIPTokenizerItem(new_tokens, ids) def expand_id(self, id: int, vector_shuffle=True): if id in self.token_map: diff --git a/train_ti.py b/train_ti.py index 3776eb2..19348e5 100644 --- a/train_ti.py +++ b/train_ti.py @@ -535,6 +535,7 @@ def main(): ] new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) + embeddings.resize(len(tokenizer)) for (new_token, init_ids) in zip(new_tokens, initializer_token_ids): embeddings.add_embed(new_token.ids, init_ids) -- cgit v1.2.3-70-g09d2