diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/ti.py | 48 |
1 files changed, 0 insertions, 48 deletions
diff --git a/training/ti.py b/training/ti.py deleted file mode 100644 index 031fe48..0000000 --- a/training/ti.py +++ /dev/null | |||
@@ -1,48 +0,0 @@ | |||
1 | from typing import Optional | ||
2 | |||
3 | import torch | ||
4 | import torch.nn as nn | ||
5 | |||
6 | from transformers.models.clip import CLIPTextModel, CLIPTextConfig | ||
7 | from transformers.models.clip.modeling_clip import CLIPTextEmbeddings | ||
8 | |||
9 | |||
10 | def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]): | ||
11 | text_embeddings = TrainableEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, new_ids) | ||
12 | text_encoder.text_model.embeddings = text_embeddings | ||
13 | |||
14 | |||
15 | class TrainableEmbeddings(CLIPTextEmbeddings): | ||
16 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]): | ||
17 | super().__init__(config) | ||
18 | |||
19 | self.token_embedding = embeddings.token_embedding | ||
20 | self.position_embedding = embeddings.position_embedding | ||
21 | |||
22 | self.train_indices = torch.tensor(new_ids) | ||
23 | |||
24 | self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) | ||
25 | self.trainable_embedding.weight.data.zero_() | ||
26 | self.trainable_embedding.weight.data[self.train_indices] = self.token_embedding.weight.data[self.train_indices] | ||
27 | |||
28 | def forward( | ||
29 | self, | ||
30 | input_ids: Optional[torch.LongTensor] = None, | ||
31 | position_ids: Optional[torch.LongTensor] = None, | ||
32 | inputs_embeds: Optional[torch.FloatTensor] = None, | ||
33 | ) -> torch.Tensor: | ||
34 | device = input_ids.device | ||
35 | seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] | ||
36 | |||
37 | if position_ids is None: | ||
38 | position_ids = self.position_ids[:, :seq_length] | ||
39 | |||
40 | if inputs_embeds is None: | ||
41 | mask = torch.isin(input_ids, self.train_indices.to(device)) | ||
42 | inputs_embeds = self.token_embedding(input_ids) | ||
43 | inputs_embeds[mask] = self.trainable_embedding(input_ids)[mask] | ||
44 | |||
45 | position_embeddings = self.position_embedding(position_ids) | ||
46 | embeddings = inputs_embeds + position_embeddings | ||
47 | |||
48 | return embeddings | ||