From 86e908656bcd7585ec45cd930176800f759f146a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 1 Apr 2023 17:33:00 +0200
Subject: Combined TI with embedding and LoRA

---
 models/clip/embeddings.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

(limited to 'models')

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 88e0cc0..c9c788c 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -66,12 +66,20 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.initializer_factor = config.initializer_factor
 
         self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank)
+        self.temp_token_embedding = nn.Embedding(
+            self.token_embedding.num_embeddings,
+            self.token_embedding.embedding_dim,
+            device=self.token_embedding.weight.device,
+            dtype=self.token_embedding.weight.dtype
+        )
+        self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def reset_overlay(self):
         self.overlay.reset()
 
     def resize(self, size: int):
+        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(
@@ -106,6 +114,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
+        self.temp_token_embedding.weight.data[token_ids] = initializer
         self.token_embedding.weight.data[token_ids] = initializer
 
     def load_embed(self, input_ids: list[int], filename: Path):
@@ -116,9 +125,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        self.token_embedding.weight.data[self.temp_token_ids] += self.overlay(
-            self.token_embedding.weight.data[self.temp_token_ids]
-        )
+        embeds = self.temp_token_embedding.weight.data[self.temp_token_ids]
+        self.token_embedding.weight.data[self.temp_token_ids] = embeds + self.overlay(embeds)
         self.overlay.reset()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
@@ -127,8 +135,11 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
         embeds = self.token_embedding(input_ids)
+
         mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
-        embeds[mask] += self.overlay(embeds[mask])
+
+        temp_embeds = self.temp_token_embedding(input_ids[mask])
+        embeds[mask] = temp_embeds + self.overlay(temp_embeds)
 
         return embeds
-- 
cgit v1.2.3-54-g00ecf