From 8161559d3529f164939ba9a41dfc9c6dfc8c4be2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 23 Dec 2022 12:09:29 +0100 Subject: Simplified trainable embedding code again --- training/ti.py | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) (limited to 'training') diff --git a/training/ti.py b/training/ti.py index 2e5139a..a5e407b 100644 --- a/training/ti.py +++ b/training/ti.py @@ -25,12 +25,10 @@ class TrainableEmbeddings(CLIPTextEmbeddings): def __init__(self, config: CLIPTextConfig, new_ids: list[int]): super().__init__(config) - self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))} - self.train_indices = torch.tensor(new_ids) - self.trainable_embedding = nn.Embedding(len(new_ids), self.token_embedding.embedding_dim) - self.trainable_embedding.weight.data = self.token_embedding.weight.data[self.train_indices] + self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) + self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone() self.trainable_embedding.weight.requires_grad = True def forward( @@ -39,27 +37,16 @@ class TrainableEmbeddings(CLIPTextEmbeddings): position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: + device = input_ids.device seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] if position_ids is None: position_ids = self.position_ids[:, :seq_length] if inputs_embeds is None: - mask = torch.isin(input_ids, self.train_indices.to(input_ids.device))[:, :, None] - - trainable_input_ids = torch.tensor([ - [ - self.id_mapping[id] if id in self.id_mapping else 0 - for id in batch - ] - for batch in input_ids - ], device=input_ids.device) - - inputs_embeds = torch.where( - mask, - self.trainable_embedding(trainable_input_ids), - self.token_embedding(input_ids) - ) + mask = torch.isin(input_ids, self.train_indices.to(device)) + inputs_embeds = self.token_embedding(input_ids) + inputs_embeds[mask] = self.trainable_embedding(input_ids)[mask] position_embeddings = self.position_embedding(position_ids) embeddings = inputs_embeds + position_embeddings -- cgit v1.2.3-54-g00ecf