diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/ti.py | 25 |
1 file changed, 6 insertions, 19 deletions
diff --git a/training/ti.py b/training/ti.py index 2e5139a..a5e407b 100644 --- a/training/ti.py +++ b/training/ti.py | |||
@@ -25,12 +25,10 @@ class TrainableEmbeddings(CLIPTextEmbeddings): | |||
25 | def __init__(self, config: CLIPTextConfig, new_ids: list[int]): | 25 | def __init__(self, config: CLIPTextConfig, new_ids: list[int]): |
26 | super().__init__(config) | 26 | super().__init__(config) |
27 | 27 | ||
28 | self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))} | ||
29 | |||
30 | self.train_indices = torch.tensor(new_ids) | 28 | self.train_indices = torch.tensor(new_ids) |
31 | 29 | ||
32 | self.trainable_embedding = nn.Embedding(len(new_ids), self.token_embedding.embedding_dim) | 30 | self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) |
33 | self.trainable_embedding.weight.data = self.token_embedding.weight.data[self.train_indices] | 31 | self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone() |
34 | self.trainable_embedding.weight.requires_grad = True | 32 | self.trainable_embedding.weight.requires_grad = True |
35 | 33 | ||
36 | def forward( | 34 | def forward( |
@@ -39,27 +37,16 @@ class TrainableEmbeddings(CLIPTextEmbeddings): | |||
39 | position_ids: Optional[torch.LongTensor] = None, | 37 | position_ids: Optional[torch.LongTensor] = None, |
40 | inputs_embeds: Optional[torch.FloatTensor] = None, | 38 | inputs_embeds: Optional[torch.FloatTensor] = None, |
41 | ) -> torch.Tensor: | 39 | ) -> torch.Tensor: |
40 | device = input_ids.device | ||
42 | seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] | 41 | seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] |
43 | 42 | ||
44 | if position_ids is None: | 43 | if position_ids is None: |
45 | position_ids = self.position_ids[:, :seq_length] | 44 | position_ids = self.position_ids[:, :seq_length] |
46 | 45 | ||
47 | if inputs_embeds is None: | 46 | if inputs_embeds is None: |
48 | mask = torch.isin(input_ids, self.train_indices.to(input_ids.device))[:, :, None] | 47 | mask = torch.isin(input_ids, self.train_indices.to(device)) |
49 | 48 | inputs_embeds = self.token_embedding(input_ids) | |
50 | trainable_input_ids = torch.tensor([ | 49 | inputs_embeds[mask] = self.trainable_embedding(input_ids)[mask] |
51 | [ | ||
52 | self.id_mapping[id] if id in self.id_mapping else 0 | ||
53 | for id in batch | ||
54 | ] | ||
55 | for batch in input_ids | ||
56 | ], device=input_ids.device) | ||
57 | |||
58 | inputs_embeds = torch.where( | ||
59 | mask, | ||
60 | self.trainable_embedding(trainable_input_ids), | ||
61 | self.token_embedding(input_ids) | ||
62 | ) | ||
63 | 50 | ||
64 | position_embeddings = self.position_embedding(position_ids) | 51 | position_embeddings = self.position_embedding(position_ids) |
65 | embeddings = inputs_embeds + position_embeddings | 52 | embeddings = inputs_embeds + position_embeddings |