from typing import Union, Optional
from pathlib import Path

import torch
import torch.nn as nn

from safetensors import safe_open
from safetensors.torch import save_file

from transformers import CLIPTextModel
from transformers.models.clip import CLIPTextConfig
from transformers.models.clip.modeling_clip import CLIPTextEmbeddings


def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: float = 1.0) -> nn.Embedding:
    """Return an embedding layer with `new_num_embeddings` rows, copying the
    overlapping rows from `old_embedding` and randomly initializing any new rows."""
    old_num_embeddings, old_embedding_dim = old_embedding.weight.shape

    if old_num_embeddings == new_num_embeddings:
        return old_embedding

    n = min(old_num_embeddings, new_num_embeddings)

    new_embedding = nn.Embedding(
        new_num_embeddings,
        old_embedding_dim,
        device=old_embedding.weight.device,
        dtype=old_embedding.weight.dtype,
    )
    new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
    new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :]
    return new_embedding


class OverlayLinear(nn.Module):
    """Low-rank residual projection (down -> up). The up projection starts out
    zeroed, so a freshly reset overlay contributes nothing until it is trained."""

    def __init__(self, in_features, out_features, rank=4):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"Rank {rank} must be less than or equal to {min(in_features, out_features)}")

        self.rank = rank
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        self.reset()

    def reset(self):
        nn.init.normal_(self.down.weight, std=1 / self.rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        return up_hidden_states.to(orig_dtype)


class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
    """Drop-in replacement for `CLIPTextEmbeddings` that tracks newly added
    ("temporary") token ids and routes their embeddings through a trainable
    low-rank overlay until they are persisted into the embedding matrix."""

    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, rank: int = 128):
        super().__init__(config)

        self.token_embedding = embeddings.token_embedding
        self.position_embedding = embeddings.position_embedding
        self.initializer_factor = config.initializer_factor

        self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank)
        self.temp_token_ids = torch.tensor([], dtype=torch.long)

    def reset_overlay(self):
        self.overlay.reset()

    def resize(self, size: int):
        self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)

    def add_embed(
        self,
        token_ids: Union[int, list[int]],
        initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None,
        initializer_noise: float = 0.0,
    ):
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        # By default, initialize new tokens from their own existing embedding rows.
        if initializer is None:
            initializer = token_ids

        if isinstance(initializer, int):
            initializer = [initializer]

        if isinstance(initializer, list):
            # Repeat or truncate the initializer ids to match the number of new
            # tokens, then look up their current embeddings.
            initializer = (initializer * len(token_ids))[:len(token_ids)]

            with torch.no_grad():
                initializer = self.get_embed(initializer)

        initializer = initializer.to(
            device=self.token_embedding.weight.device,
            dtype=self.token_embedding.weight.dtype,
        )

        if initializer_noise != 0:
            initializer += torch.randn_like(initializer) * initializer_noise

        token_ids = torch.tensor(token_ids, dtype=torch.long)

        self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
        self.token_embedding.weight.data[token_ids] = initializer

    def load_embed(self, input_ids: list[int], filename: Path):
        with safe_open(filename, framework="pt", device="cpu") as file:
            self.add_embed(input_ids, file.get_tensor("embed"))

    def save_embed(self, input_ids: list[int], filename: Path):
        save_file({"embed": self.get_embed(input_ids)}, filename)

    def persist(self):
        # Bake the overlay's contribution into the embedding matrix for the
        # temporary tokens, then reset the overlay and forget the temporary ids.
        self.token_embedding.weight.data[self.temp_token_ids] += self.overlay(
            self.token_embedding.weight.data[self.temp_token_ids]
        )
        self.overlay.reset()
        self.temp_token_ids = torch.tensor([], dtype=torch.long)

    def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
        if isinstance(input_ids, list):
            input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)

        embeds = self.token_embedding(input_ids)

        # Apply the overlay only to temporary (not yet persisted) tokens.
        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
        embeds[mask] += self.overlay(embeds[mask])

        return embeds

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.get_embed(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings


def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
    """Replace the text encoder's embedding module with a managed version and return it."""
    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
    text_encoder.text_model.embeddings = text_embeddings
    return text_embeddings
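

# Usage sketch (illustrative, not part of the module above): one plausible way to
# wire these helpers into a textual-inversion-style workflow. The checkpoint name,
# placeholder token, and initializer word below are arbitrary assumptions, and the
# actual training loop is omitted.
if __name__ == "__main__":
    from transformers import CLIPTokenizer

    model_name = "openai/clip-vit-large-patch14"
    tokenizer = CLIPTokenizer.from_pretrained(model_name)
    text_encoder = CLIPTextModel.from_pretrained(model_name)

    # Swap in the managed embedding module.
    managed = patch_managed_embeddings(text_encoder)

    # Register a placeholder token and initialize its embedding from an existing word.
    tokenizer.add_tokens(["<my-concept>"])
    placeholder_id = tokenizer.convert_tokens_to_ids("<my-concept>")
    initializer_id = tokenizer("cat", add_special_tokens=False).input_ids[0]

    managed.resize(len(tokenizer))
    managed.add_embed(placeholder_id, initializer_id)

    # ... training would go here: the overlay only affects the temporary tokens ...

    # Bake the (here still untrained) overlay into the embedding matrix and save the row.
    managed.persist()
    with torch.no_grad():
        managed.save_embed([placeholder_id], Path("my-concept.safetensors"))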