from typing import Optional

import torch
import torch.nn as nn
from transformers.models.clip import CLIPTextModel, CLIPTextConfig
from transformers.models.clip.modeling_clip import CLIPTextEmbeddings


def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
    """Replace ``text_encoder.text_model.embeddings`` with a :class:`TrainableEmbeddings`.

    The installed module shares (does not copy) the encoder's existing token and
    position embedding parameters and freezes them. The rows for ``new_ids`` get a
    separate trainable table, initialised from the encoder's current weights.

    Args:
        text_encoder: CLIP text model whose embedding module is replaced in place.
        new_ids: Token ids whose embeddings should be trainable.

    Returns:
        The newly installed ``TrainableEmbeddings`` module.
    """
    old_embeddings = text_encoder.text_model.embeddings
    text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids)

    # Share the pretrained parameters with the original embedding module.
    text_embeddings.token_embedding.weight = old_embeddings.token_embedding.weight
    text_embeddings.position_embedding.weight = old_embeddings.position_embedding.weight

    # The constructor froze its own freshly created parameters, which were just
    # replaced; freeze the shared pretrained parameters instead so only the
    # trainable table receives gradients.
    text_embeddings.token_embedding.weight.requires_grad = False
    text_embeddings.position_embedding.weight.requires_grad = False

    # Put buffers and the trainable table on the same device as the shared weights,
    # then seed the trainable rows from the pretrained weights (the constructor
    # could only copy its own random initialisation).
    text_embeddings.to(text_embeddings.token_embedding.weight.device)
    text_embeddings.reset_trainable_embedding()

    text_encoder.text_model.embeddings = text_embeddings
    return text_embeddings


class TrainableEmbeddings(CLIPTextEmbeddings):
    """CLIP text embeddings where only a chosen subset of token rows is trainable.

    Token and position embeddings are frozen. Tokens listed in ``new_ids`` are
    looked up in a separate ``trainable_embedding`` table instead of
    ``token_embedding``.
    """

    def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
        super().__init__(config)
        self.token_embedding.weight.requires_grad = False
        self.position_embedding.weight.requires_grad = False

        num_embeddings = self.token_embedding.num_embeddings
        indices = torch.arange(num_embeddings)
        # Sorted, de-duplicated, in-vocabulary ids. Row k of trainable_embedding
        # belongs to token id train_indices[k].
        train_indices = indices[torch.isin(indices, torch.tensor(new_ids, dtype=torch.long))]
        # Non-persistent buffers follow .to(device) without changing the state_dict.
        self.register_buffer("train_indices", train_indices, persistent=False)

        # Dense token id -> trainable row table, so forward() can map ids with one
        # vectorized gather. Ids that are not trainable map to row 0; their values
        # are discarded by the mask in forward().
        lookup = torch.zeros(num_embeddings, dtype=torch.long)
        lookup[train_indices] = torch.arange(train_indices.numel())
        self.register_buffer("trainable_lookup", lookup, persistent=False)

        # Token id -> trainable row, consistent with the table above.
        self.id_mapping = {token_id: row for row, token_id in enumerate(train_indices.tolist())}

        # freeze=False leaves the copied rows trainable.
        self.trainable_embedding = nn.Embedding.from_pretrained(
            self.token_embedding.weight[train_indices], freeze=False
        )

    @torch.no_grad()
    def reset_trainable_embedding(self) -> None:
        """Overwrite the trainable rows with the current ``token_embedding`` rows.

        Call this after replacing ``token_embedding.weight``. The copy keeps the
        trainable parameter's own dtype and device.
        """
        weight = self.token_embedding.weight
        source = weight[self.train_indices.to(weight.device)]
        self.trainable_embedding.weight.copy_(source)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Return token (or given) embeddings plus position embeddings.

        When ``inputs_embeds`` is not given, positions holding a trainable token id
        take their vector from ``trainable_embedding``. All other positions use
        ``token_embedding``.
        """
        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)
            if self.train_indices.numel() > 0:
                device = input_ids.device
                # Trailing singleton dim lets the mask broadcast over the embedding dim.
                mask = torch.isin(input_ids, self.train_indices.to(device)).unsqueeze(-1)
                trainable_ids = self.trainable_lookup.to(device)[input_ids]
                inputs_embeds = torch.where(
                    mask,
                    self.trainable_embedding(trainable_ids),
                    inputs_embeds,
                )

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings
        return embeddings