summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--training/ti.py25
1 files changed, 6 insertions, 19 deletions
diff --git a/training/ti.py b/training/ti.py
index 2e5139a..a5e407b 100644
--- a/training/ti.py
+++ b/training/ti.py
@@ -25,12 +25,10 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
25 def __init__(self, config: CLIPTextConfig, new_ids: list[int]): 25 def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
26 super().__init__(config) 26 super().__init__(config)
27 27
28 self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))}
29
30 self.train_indices = torch.tensor(new_ids) 28 self.train_indices = torch.tensor(new_ids)
31 29
32 self.trainable_embedding = nn.Embedding(len(new_ids), self.token_embedding.embedding_dim) 30 self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
33 self.trainable_embedding.weight.data = self.token_embedding.weight.data[self.train_indices] 31 self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone()
34 self.trainable_embedding.weight.requires_grad = True 32 self.trainable_embedding.weight.requires_grad = True
35 33
36 def forward( 34 def forward(
@@ -39,27 +37,16 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
39 position_ids: Optional[torch.LongTensor] = None, 37 position_ids: Optional[torch.LongTensor] = None,
40 inputs_embeds: Optional[torch.FloatTensor] = None, 38 inputs_embeds: Optional[torch.FloatTensor] = None,
41 ) -> torch.Tensor: 39 ) -> torch.Tensor:
40 device = input_ids.device
42 seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] 41 seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
43 42
44 if position_ids is None: 43 if position_ids is None:
45 position_ids = self.position_ids[:, :seq_length] 44 position_ids = self.position_ids[:, :seq_length]
46 45
47 if inputs_embeds is None: 46 if inputs_embeds is None:
48 mask = torch.isin(input_ids, self.train_indices.to(input_ids.device))[:, :, None] 47 mask = torch.isin(input_ids, self.train_indices.to(device))
49 48 inputs_embeds = self.token_embedding(input_ids)
50 trainable_input_ids = torch.tensor([ 49 inputs_embeds[mask] = self.trainable_embedding(input_ids)[mask]
51 [
52 self.id_mapping[id] if id in self.id_mapping else 0
53 for id in batch
54 ]
55 for batch in input_ids
56 ], device=input_ids.device)
57
58 inputs_embeds = torch.where(
59 mask,
60 self.trainable_embedding(trainable_input_ids),
61 self.token_embedding(input_ids)
62 )
63 50
64 position_embeddings = self.position_embedding(position_ids) 51 position_embeddings = self.position_embedding(position_ids)
65 embeddings = inputs_embeds + position_embeddings 52 embeddings = inputs_embeds + position_embeddings