From 930e11e0a94453ae5ea8621fbc5f3c9a080149d1 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 23 Dec 2022 09:10:50 +0100
Subject: Fixed and simplified trainable embeddings code

---
 training/ti.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/training/ti.py b/training/ti.py
index 2efd2f2..2e5139a 100644
--- a/training/ti.py
+++ b/training/ti.py
@@ -9,9 +9,15 @@ from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
 
 def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
     text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids)
-    text_embeddings.token_embedding.weight = text_encoder.text_model.embeddings.token_embedding.weight
-    text_embeddings.position_embedding.weight = text_encoder.text_model.embeddings.position_embedding.weight
+
+    text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding
+    text_embeddings.token_embedding.weight.requires_grad = False
+
+    text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding
+    text_embeddings.position_embedding.weight.requires_grad = False
+
     text_encoder.text_model.embeddings = text_embeddings
+
     return text_embeddings
 
@@ -19,15 +25,12 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
     def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
         super().__init__(config)
 
-        self.token_embedding.weight.requires_grad = False
-        self.position_embedding.weight.requires_grad = False
-
         self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))}
 
-        indices = torch.arange(self.token_embedding.num_embeddings)
-        self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))]
+        self.train_indices = torch.tensor(new_ids)
 
-        self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices])
+        self.trainable_embedding = nn.Embedding(len(new_ids), self.token_embedding.embedding_dim)
+        self.trainable_embedding.weight.data = self.token_embedding.weight.data[self.train_indices]
         self.trainable_embedding.weight.requires_grad = True
 
     def forward(
@@ -42,10 +45,7 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
             position_ids = self.position_ids[:, :seq_length]
 
         if inputs_embeds is None:
-            mask = torch.isin(
-                input_ids,
-                self.train_indices.to(input_ids.device)
-            ).unsqueeze(-1).expand(-1, -1, self.token_embedding.embedding_dim)
+            mask = torch.isin(input_ids, self.train_indices.to(input_ids.device))[:, :, None]
 
             trainable_input_ids = torch.tensor([
                 [
-- 
cgit v1.2.3-70-g09d2