From b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 1 Jan 2023 11:36:00 +0100
Subject: Fixed accuracy calc, other improvements

---
 training/ti.py | 48 ------------------------------------------------
 1 file changed, 48 deletions(-)
 delete mode 100644 training/ti.py

diff --git a/training/ti.py b/training/ti.py
deleted file mode 100644
index 031fe48..0000000
--- a/training/ti.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from typing import Optional
-
-import torch
-import torch.nn as nn
-
-from transformers.models.clip import CLIPTextModel, CLIPTextConfig
-from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
-
-
-def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
-    text_embeddings = TrainableEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, new_ids)
-    text_encoder.text_model.embeddings = text_embeddings
-
-
-class TrainableEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]):
-        super().__init__(config)
-
-        self.token_embedding = embeddings.token_embedding
-        self.position_embedding = embeddings.position_embedding
-
-        self.train_indices = torch.tensor(new_ids)
-
-        self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
-        self.trainable_embedding.weight.data.zero_()
-        self.trainable_embedding.weight.data[self.train_indices] = self.token_embedding.weight.data[self.train_indices]
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-    ) -> torch.Tensor:
-        device = input_ids.device
-        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
-
-        if position_ids is None:
-            position_ids = self.position_ids[:, :seq_length]
-
-        if inputs_embeds is None:
-            mask = torch.isin(input_ids, self.train_indices.to(device))
-            inputs_embeds = self.token_embedding(input_ids)
-            inputs_embeds[mask] = self.trainable_embedding(input_ids)[mask]
-
-        position_embeddings = self.position_embedding(position_ids)
-        embeddings = inputs_embeds + position_embeddings
-
-        return embeddings
--
cgit v1.2.3-70-g09d2
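
Note: the file deleted above implemented a "shadow embedding" scheme for textual inversion. It swaps the CLIP text encoder's embedding module for one that routes newly added token ids through a separate, zero-initialized nn.Embedding seeded from the original weights, so that gradient updates touch only those rows. Below is a minimal usage sketch of that mechanism as it existed before this commit. It assumes training/ti.py is still present; the checkpoint name "openai/clip-vit-large-patch14" and the placeholder token "<concept>" are illustrative choices, not taken from the repository.

import torch
from transformers import CLIPTextModel, CLIPTokenizer

from training.ti import patch_trainable_embeddings  # module removed by this commit

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Register a new placeholder token and grow the embedding table to match.
tokenizer.add_tokens(["<concept>"])
text_encoder.resize_token_embeddings(len(tokenizer))
new_ids = tokenizer.convert_tokens_to_ids(["<concept>"])

# Replace the embeddings module so only the new ids take gradient updates.
patch_trainable_embeddings(text_encoder, new_ids)

# Freeze the original table; optimize only the shadow embedding.
text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
params = text_encoder.text_model.embeddings.trainable_embedding.parameters()
optimizer = torch.optim.AdamW(params, lr=1e-4)

ids = tokenizer("a photo of <concept>", return_tensors="pt").input_ids
out = text_encoder(input_ids=ids)  # forward pass runs through TrainableEmbeddings

Because forward() copies shadow rows only at positions where input_ids matches train_indices, gradients reach just the rows for the new tokens; every other token keeps its frozen pretrained embedding.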