diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/ti.py | 70 |
1 files changed, 70 insertions, 0 deletions
diff --git a/training/ti.py b/training/ti.py new file mode 100644 index 0000000..a5fd8e4 --- /dev/null +++ b/training/ti.py | |||
@@ -0,0 +1,70 @@ | |||
1 | from typing import Optional | ||
2 | |||
3 | import torch | ||
4 | import torch.nn as nn | ||
5 | |||
6 | from transformers.models.clip import CLIPTextModel, CLIPTextConfig | ||
7 | from transformers.models.clip.modeling_clip import CLIPTextEmbeddings | ||
8 | |||
9 | |||
10 | def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]): | ||
11 | text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids) | ||
12 | text_embeddings.token_embedding.weight = text_encoder.text_model.embeddings.token_embedding.weight | ||
13 | text_embeddings.position_embedding.weight = text_encoder.text_model.embeddings.position_embedding.weight | ||
14 | text_encoder.text_model.embeddings = text_embeddings | ||
15 | return text_embeddings | ||
16 | |||
17 | |||
18 | class TrainableEmbeddings(CLIPTextEmbeddings): | ||
19 | def __init__(self, config: CLIPTextConfig, new_ids: list[int]): | ||
20 | super().__init__(config) | ||
21 | |||
22 | self.token_embedding.requires_grad_(False) | ||
23 | self.position_embedding.requires_grad_(False) | ||
24 | |||
25 | self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))} | ||
26 | |||
27 | indices = torch.arange(self.token_embedding.num_embeddings) | ||
28 | self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))] | ||
29 | |||
30 | self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices]) | ||
31 | |||
32 | def forward( | ||
33 | self, | ||
34 | input_ids: Optional[torch.LongTensor] = None, | ||
35 | position_ids: Optional[torch.LongTensor] = None, | ||
36 | inputs_embeds: Optional[torch.FloatTensor] = None, | ||
37 | ) -> torch.Tensor: | ||
38 | seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] | ||
39 | |||
40 | if position_ids is None: | ||
41 | position_ids = self.position_ids[:, :seq_length] | ||
42 | |||
43 | if inputs_embeds is None: | ||
44 | mask = torch.isin( | ||
45 | input_ids, | ||
46 | self.train_indices.to(input_ids.device) | ||
47 | ).unsqueeze(-1).expand(-1, -1, self.token_embedding.embedding_dim) | ||
48 | |||
49 | trainable_input_ids = torch.tensor([ | ||
50 | [ | ||
51 | self.id_mapping[id] if id in self.id_mapping else 0 | ||
52 | for id in batch | ||
53 | ] | ||
54 | for batch in input_ids | ||
55 | ], device=input_ids.device) | ||
56 | |||
57 | inputs_embeds = torch.where( | ||
58 | mask, | ||
59 | self.trainable_embedding(trainable_input_ids), | ||
60 | self.token_embedding(input_ids) | ||
61 | ) | ||
62 | |||
63 | position_embeddings = self.position_embedding(position_ids) | ||
64 | embeddings = inputs_embeds + position_embeddings | ||
65 | |||
66 | return embeddings | ||
67 | |||
68 | @torch.no_grad() | ||
69 | def save(self): | ||
70 | self.token_embedding.weight.data[self.train_indices] = self.trainable_embedding.weight.data | ||