from typing import Union, Optional
from pathlib import Path

import torch
from safetensors import safe_open
from safetensors.torch import save_file
from transformers import CLIPTextModel
from transformers.models.clip import CLIPTextConfig
from transformers.models.clip.modeling_clip import CLIPTextEmbeddings

from models.sparse import SparseEmbedding


class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
    """CLIP text embeddings whose token embedding is backed by a SparseEmbedding.

    The original vocabulary stays frozen; only rows registered through
    `add_embed` are marked trainable.
    """

    def __init__(
        self,
        config: CLIPTextConfig,
        embeddings: CLIPTextEmbeddings,
        alpha: int = 8,
        dropout: float = 0.0,
    ):
        super().__init__(config)

        # Reuse the existing position embedding and wrap the token embedding
        # in a SparseEmbedding that shares the pretrained weight.
        self.position_embedding = embeddings.position_embedding
        self.initializer_factor = config.initializer_factor
        self.token_embedding = SparseEmbedding(
            self.token_embedding.num_embeddings,
            self.token_embedding.embedding_dim,
            alpha,
            dropout,
        )
        self.token_embedding.weight = embeddings.token_embedding.weight

    def resize(self, size: int):
        # Grow (or shrink) the token embedding to `size` rows; new rows are
        # initialized using the model's initializer factor.
        self.token_embedding = self.token_embedding.new_resized(
            size, self.initializer_factor
        )

    def add_embed(
        self,
        token_ids: Union[int, list[int]],
        initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None,
        initializer_noise: float = 0.0,
    ):
        """Register new token ids and initialize their embedding rows.

        The initializer may be existing token ids (whose embeddings are copied)
        or an explicit tensor; optional Gaussian noise can be added on top.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        if initializer is None:
            initializer = token_ids

        if isinstance(initializer, int):
            initializer = [initializer]

        if isinstance(initializer, list):
            # Tile or truncate the initializer ids to match the number of new
            # tokens, then look up their current embeddings.
            initializer = (initializer * len(token_ids))[: len(token_ids)]

            with torch.no_grad():
                initializer = self.get_embed(initializer)

        initializer = initializer.to(
            device=self.token_embedding.weight.device,
            dtype=self.token_embedding.weight.dtype,
        )

        if initializer_noise != 0:
            initializer += torch.randn_like(initializer) * initializer_noise

        token_ids = torch.tensor(token_ids, dtype=torch.long)

        self.token_embedding.mark_trainable(token_ids)
        self.token_embedding.weight.data[token_ids] = initializer

    def load_embed(self, input_ids: list[int], filename: Path):
        # Load embeddings for `input_ids` from a safetensors file.
        with safe_open(filename, framework="pt", device="cpu") as file:
            self.add_embed(input_ids, file.get_tensor("embed"))

    def save_embed(self, input_ids: list[int], filename: Path):
        # Save the embeddings for `input_ids` to a safetensors file.
        save_file({"embed": self.get_embed(input_ids)}, filename)

    def persist(self, clear: bool = False):
        self.token_embedding.persist(clear)

    def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
        # Accept either a plain list of ids or a LongTensor.
        if isinstance(input_ids, list):
            input_ids = torch.tensor(
                input_ids, device=self.token_embedding.weight.device, dtype=torch.long
            )

        return self.token_embedding(input_ids)


def patch_managed_embeddings(
    text_encoder: CLIPTextModel, alpha: int = 8, dropout: float = 0.0
) -> ManagedCLIPTextEmbeddings:
    """Replace the text encoder's embeddings with a managed version (idempotent)."""
    if isinstance(text_encoder.text_model.embeddings, ManagedCLIPTextEmbeddings):
        return text_encoder.text_model.embeddings

    text_embeddings = ManagedCLIPTextEmbeddings(
        text_encoder.config, text_encoder.text_model.embeddings, alpha, dropout
    )
    text_encoder.text_model.embeddings = text_embeddings
    return text_embeddings
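

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module's public surface): how
# the patched embeddings are typically driven when adding a new placeholder
# token. The model id, the placeholder token "<concept>", and the initializer
# word "cat" are assumptions for the example; the flow also assumes the
# SparseEmbedding interface used above (mark_trainable, new_resized, persist).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    embeddings = patch_managed_embeddings(text_encoder, alpha=8, dropout=0.0)

    # Register a placeholder token and grow the embedding matrix to match.
    tokenizer.add_tokens(["<concept>"])
    embeddings.resize(len(tokenizer))

    # Initialize the new token from an existing word so training starts close
    # to a meaningful point in embedding space.
    new_id = tokenizer.convert_tokens_to_ids("<concept>")
    init_ids = tokenizer.encode("cat", add_special_tokens=False)
    embeddings.add_embed(new_id, init_ids)

    # Only the newly added row is trainable; its embedding can be saved once trained.
    embeddings.save_embed([new_id], Path("concept.safetensors"))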