diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/ti.py | 25 |
1 files changed, 6 insertions, 19 deletions
diff --git a/training/ti.py b/training/ti.py index 2e5139a..a5e407b 100644 --- a/training/ti.py +++ b/training/ti.py | |||
| @@ -25,12 +25,10 @@ class TrainableEmbeddings(CLIPTextEmbeddings): | |||
| 25 | def __init__(self, config: CLIPTextConfig, new_ids: list[int]): | 25 | def __init__(self, config: CLIPTextConfig, new_ids: list[int]): |
| 26 | super().__init__(config) | 26 | super().__init__(config) |
| 27 | 27 | ||
| 28 | self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))} | ||
| 29 | |||
| 30 | self.train_indices = torch.tensor(new_ids) | 28 | self.train_indices = torch.tensor(new_ids) |
| 31 | 29 | ||
| 32 | self.trainable_embedding = nn.Embedding(len(new_ids), self.token_embedding.embedding_dim) | 30 | self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) |
| 33 | self.trainable_embedding.weight.data = self.token_embedding.weight.data[self.train_indices] | 31 | self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone() |
| 34 | self.trainable_embedding.weight.requires_grad = True | 32 | self.trainable_embedding.weight.requires_grad = True |
| 35 | 33 | ||
| 36 | def forward( | 34 | def forward( |
| @@ -39,27 +37,16 @@ class TrainableEmbeddings(CLIPTextEmbeddings): | |||
| 39 | position_ids: Optional[torch.LongTensor] = None, | 37 | position_ids: Optional[torch.LongTensor] = None, |
| 40 | inputs_embeds: Optional[torch.FloatTensor] = None, | 38 | inputs_embeds: Optional[torch.FloatTensor] = None, |
| 41 | ) -> torch.Tensor: | 39 | ) -> torch.Tensor: |
| 40 | device = input_ids.device | ||
| 42 | seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] | 41 | seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] |
| 43 | 42 | ||
| 44 | if position_ids is None: | 43 | if position_ids is None: |
| 45 | position_ids = self.position_ids[:, :seq_length] | 44 | position_ids = self.position_ids[:, :seq_length] |
| 46 | 45 | ||
| 47 | if inputs_embeds is None: | 46 | if inputs_embeds is None: |
| 48 | mask = torch.isin(input_ids, self.train_indices.to(input_ids.device))[:, :, None] | 47 | mask = torch.isin(input_ids, self.train_indices.to(device)) |
| 49 | 48 | inputs_embeds = self.token_embedding(input_ids) | |
| 50 | trainable_input_ids = torch.tensor([ | 49 | inputs_embeds[mask] = self.trainable_embedding(input_ids)[mask] |
| 51 | [ | ||
| 52 | self.id_mapping[id] if id in self.id_mapping else 0 | ||
| 53 | for id in batch | ||
| 54 | ] | ||
| 55 | for batch in input_ids | ||
| 56 | ], device=input_ids.device) | ||
| 57 | |||
| 58 | inputs_embeds = torch.where( | ||
| 59 | mask, | ||
| 60 | self.trainable_embedding(trainable_input_ids), | ||
| 61 | self.token_embedding(input_ids) | ||
| 62 | ) | ||
| 63 | 50 | ||
| 64 | position_embeddings = self.position_embedding(position_ids) | 51 | position_embeddings = self.position_embedding(position_ids) |
| 65 | embeddings = inputs_embeds + position_embeddings | 52 | embeddings = inputs_embeds + position_embeddings |
