Diffstat (limited to 'training')
-rw-r--r--   training/ti.py   70
1 file changed, 70 insertions, 0 deletions
diff --git a/training/ti.py b/training/ti.py
new file mode 100644
index 0000000..a5fd8e4
--- /dev/null
+++ b/training/ti.py
@@ -0,0 +1,70 @@
from typing import Optional

import torch
import torch.nn as nn

from transformers.models.clip import CLIPTextModel, CLIPTextConfig
from transformers.models.clip.modeling_clip import CLIPTextEmbeddings

def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
    text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids)
    # Share the encoder's pretrained embedding weights instead of keeping the
    # randomly initialized copies created by the constructor.
    text_embeddings.token_embedding.weight = text_encoder.text_model.embeddings.token_embedding.weight
    text_embeddings.position_embedding.weight = text_encoder.text_model.embeddings.position_embedding.weight
    # Re-slice the trainable rows from the now-shared pretrained weights and
    # re-freeze the base matrices (assigning a new Parameter resets requires_grad).
    text_embeddings.trainable_embedding.weight.data.copy_(
        text_embeddings.token_embedding.weight.data[text_embeddings.train_indices])
    text_embeddings.token_embedding.requires_grad_(False)
    text_embeddings.position_embedding.requires_grad_(False)
    text_encoder.text_model.embeddings = text_embeddings
    return text_embeddings


class TrainableEmbeddings(CLIPTextEmbeddings):
    def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
        super().__init__(config)

        # The full vocabulary stays frozen; only the rows for new_ids are
        # trained, through the separate trainable_embedding below.
        self.token_embedding.requires_grad_(False)
        self.position_embedding.requires_grad_(False)

        indices = torch.arange(self.token_embedding.num_embeddings)
        self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))]

        # train_indices is sorted ascending, so build the id -> row mapping in
        # that same order to keep it aligned with the weight slice below.
        self.id_mapping = {token_id: i for i, token_id in enumerate(self.train_indices.tolist())}

        # freeze=False is required here: from_pretrained freezes by default,
        # which would leave these embeddings untrainable.
        self.trainable_embedding = nn.Embedding.from_pretrained(
            self.token_embedding.weight[self.train_indices], freeze=False)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            # Per-position flag: does this token come from the trainable table?
            mask = torch.isin(
                input_ids,
                self.train_indices.to(input_ids.device)
            ).unsqueeze(-1).expand(-1, -1, self.token_embedding.embedding_dim)

            # Remap new token ids to rows of the small trainable table; ids not
            # being trained fall back to row 0 and are masked out below. The
            # .tolist() matters: iterating a tensor yields 0-dim tensors, which
            # never match the int keys of id_mapping.
            trainable_input_ids = torch.tensor([
                [self.id_mapping.get(token_id, 0) for token_id in batch]
                for batch in input_ids.tolist()
            ], device=input_ids.device)

            # Take trainable rows where masked, frozen vocabulary rows elsewhere.
            inputs_embeds = torch.where(
                mask,
                self.trainable_embedding(trainable_input_ids),
                self.token_embedding(input_ids)
            )

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings

    @torch.no_grad()
    def save(self):
        # Write the learned rows back into the shared token embedding matrix so
        # the patched encoder can be saved or used like a regular CLIPTextModel.
        self.token_embedding.weight.data[self.train_indices] = self.trainable_embedding.weight.data
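
For context, a minimal usage sketch of how this module could be wired into a textual-inversion run. Only patch_trainable_embeddings and save() come from this commit; the model name, the "<concept>" placeholder token, and the optimizer setup are illustrative assumptions, not part of the diff.

import torch
from transformers import CLIPTokenizer, CLIPTextModel

from training.ti import patch_trainable_embeddings

# Hypothetical setup -- model name and placeholder token are examples only.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Register a new placeholder token and grow the embedding matrix to match.
tokenizer.add_tokens(["<concept>"])
text_encoder.resize_token_embeddings(len(tokenizer))
new_ids = tokenizer.convert_tokens_to_ids(["<concept>"])

# Swap in the patched embeddings; only the new rows receive gradients.
embeddings = patch_trainable_embeddings(text_encoder, new_ids)
optimizer = torch.optim.AdamW(embeddings.trainable_embedding.parameters(), lr=1e-4)

# ... training loop that backpropagates into embeddings.trainable_embedding ...

# Fold the learned rows back into the encoder before persisting it.
embeddings.save()
text_encoder.save_pretrained("patched-text-encoder")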
