From c96073646bbb638d7d78fdd7d9fdeed08d1454b5 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 1 Apr 2023 16:30:36 +0200 Subject: Experimental: TI via LoRA --- models/clip/embeddings.py | 53 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 15 deletions(-) (limited to 'models/clip/embeddings.py') diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 9abd1bb..88e0cc0 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -31,25 +31,47 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi return new_embedding +class OverlayLinear(nn.Module): + def __init__(self, in_features, out_features, rank=4): + super().__init__() + + if rank > min(in_features, out_features): + raise ValueError(f"Rank {rank} must be less or equal than {min(in_features, out_features)}") + + self.rank = rank + self.down = nn.Linear(in_features, rank, bias=False) + self.up = nn.Linear(rank, out_features, bias=False) + self.reset() + + def reset(self): + nn.init.normal_(self.down.weight, std=1 / self.rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + return up_hidden_states.to(orig_dtype) + + class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): - def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings): + def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, rank: int = 128): super().__init__(config) self.token_embedding = embeddings.token_embedding self.position_embedding = embeddings.position_embedding self.initializer_factor = config.initializer_factor - self.temp_token_embedding = nn.Embedding( - self.token_embedding.num_embeddings, - self.token_embedding.embedding_dim, - device=self.token_embedding.weight.device, - dtype=self.token_embedding.weight.dtype - ) - self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach() + self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank) self.temp_token_ids = torch.tensor([], dtype=torch.long) + def reset_overlay(self): + self.overlay.reset() + def resize(self, size: int): - self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) def add_embed( @@ -74,8 +96,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): initializer = self.get_embed(initializer) initializer = initializer.to( - device=self.temp_token_embedding.weight.device, - dtype=self.temp_token_embedding.weight.dtype, + device=self.token_embedding.weight.device, + dtype=self.token_embedding.weight.dtype, ) if initializer_noise != 0: @@ -84,7 +106,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): token_ids = torch.tensor(token_ids, dtype=torch.long) self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) - self.temp_token_embedding.weight.data[token_ids] = initializer self.token_embedding.weight.data[token_ids] = initializer def load_embed(self, input_ids: list[int], filename: Path): @@ -95,7 +116,10 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): save_file({"embed": self.get_embed(input_ids)}, filename) def persist(self): - self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] + self.token_embedding.weight.data[self.temp_token_ids] += self.overlay( + self.token_embedding.weight.data[self.temp_token_ids] + ) + self.overlay.reset() self.temp_token_ids = torch.tensor([], dtype=torch.long) def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): @@ -103,9 +127,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) embeds = self.token_embedding(input_ids) - mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) - embeds[mask] = self.temp_token_embedding(input_ids)[mask] + embeds[mask] += self.overlay(embeds[mask]) return embeds -- cgit v1.2.3-70-g09d2