From 208e48134e324e934ad964bdc61880cc923f4c0d Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 1 Apr 2023 22:13:55 +0200
Subject: Revert

---
 models/clip/embeddings.py | 42 ++++--------------------------------------
 1 file changed, 4 insertions(+), 38 deletions(-)

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index c9c788c..1e21965 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -31,41 +31,15 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
     return new_embedding
 
 
-class OverlayLinear(nn.Module):
-    def __init__(self, in_features, out_features, rank=4):
-        super().__init__()
-
-        if rank > min(in_features, out_features):
-            raise ValueError(f"Rank {rank} must be less or equal than {min(in_features, out_features)}")
-
-        self.rank = rank
-        self.down = nn.Linear(in_features, rank, bias=False)
-        self.up = nn.Linear(rank, out_features, bias=False)
-        self.reset()
-
-    def reset(self):
-        nn.init.normal_(self.down.weight, std=1 / self.rank)
-        nn.init.zeros_(self.up.weight)
-
-    def forward(self, hidden_states):
-        orig_dtype = hidden_states.dtype
-        dtype = self.down.weight.dtype
-
-        down_hidden_states = self.down(hidden_states.to(dtype))
-        up_hidden_states = self.up(down_hidden_states)
-
-        return up_hidden_states.to(orig_dtype)
-
-
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, rank: int = 128):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0, rank: int = 4):
         super().__init__(config)
 
         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
+        self.alpha = alpha
 
-        self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank)
         self.temp_token_embedding = nn.Embedding(
             self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
@@ -75,9 +49,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
-    def reset_overlay(self):
-        self.overlay.reset()
-
     def resize(self, size: int):
         self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
@@ -125,9 +96,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        embeds = self.temp_token_embedding.weight.data[self.temp_token_ids]
-        self.token_embedding.weight.data[self.temp_token_ids] = embeds + self.overlay(embeds)
-        self.overlay.reset()
+        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
@@ -135,11 +104,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         if isinstance(input_ids, list):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
         embeds = self.token_embedding(input_ids)
-
         mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
-
-        temp_embeds = self.temp_token_embedding(input_ids[mask])
-        embeds[mask] = temp_embeds + self.overlay(temp_embeds)
+        embeds[mask] = self.temp_token_embedding(input_ids[mask])
 
         return embeds
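
Note: the removed OverlayLinear is a LoRA-style low-rank residual — a down-projection to `rank` dimensions followed by a zero-initialized up-projection, so a freshly reset overlay contributes exactly zero. The sketch below is a minimal, self-contained illustration (the toy dimensions and `get_embed_*` helper names are assumptions for this note, not part of the repository) contrasting the reverted lookup, `temp_embeds + overlay(temp_embeds)`, with the plain temporary-embedding lookup this commit restores:

```python
import torch
import torch.nn as nn

# Toy sizes for illustration only.
vocab, dim, rank = 16, 8, 4

token_embedding = nn.Embedding(vocab, dim)
temp_token_embedding = nn.Embedding(vocab, dim)
temp_token_embedding.weight.data = token_embedding.weight.data.clone().detach()
temp_token_ids = torch.tensor([3, 7], dtype=torch.long)

# LoRA-style overlay as in the removed OverlayLinear: project down to
# `rank`, project back up, with the up weights zero-initialized so the
# overlay is a no-op until it has been trained.
down = nn.Linear(dim, rank, bias=False)
up = nn.Linear(rank, dim, bias=False)
nn.init.normal_(down.weight, std=1 / rank)
nn.init.zeros_(up.weight)

def get_embed_with_overlay(input_ids: torch.LongTensor) -> torch.Tensor:
    # Pre-revert behaviour: temp embedding plus low-rank correction.
    embeds = token_embedding(input_ids)
    mask = torch.isin(input_ids, temp_token_ids)
    temp_embeds = temp_token_embedding(input_ids[mask])
    embeds[mask] = temp_embeds + up(down(temp_embeds))
    return embeds

def get_embed_plain(input_ids: torch.LongTensor) -> torch.Tensor:
    # Post-revert behaviour: temp embedding used directly.
    embeds = token_embedding(input_ids)
    mask = torch.isin(input_ids, temp_token_ids)
    embeds[mask] = temp_token_embedding(input_ids[mask])
    return embeds

ids = torch.tensor([1, 3, 7])
# With `up` zero-initialized, the two paths agree before any training.
assert torch.allclose(get_embed_with_overlay(ids), get_embed_plain(ids))
```

Because the up-projection starts at zero, both paths coincide until the overlay is trained, so the revert is behaviour-preserving for untrained overlays; embeddings trained through the overlay, however, would have baked the extra term into `token_embedding` via the old `persist()`.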