summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/ti.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/training/ti.py b/training/ti.py
new file mode 100644
index 0000000..a5fd8e4
--- /dev/null
+++ b/training/ti.py
@@ -0,0 +1,70 @@
1from typing import Optional
2
3import torch
4import torch.nn as nn
5
6from transformers.models.clip import CLIPTextModel, CLIPTextConfig
7from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
8
9
10def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
11 text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids)
12 text_embeddings.token_embedding.weight = text_encoder.text_model.embeddings.token_embedding.weight
13 text_embeddings.position_embedding.weight = text_encoder.text_model.embeddings.position_embedding.weight
14 text_encoder.text_model.embeddings = text_embeddings
15 return text_embeddings
16
17
18class TrainableEmbeddings(CLIPTextEmbeddings):
19 def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
20 super().__init__(config)
21
22 self.token_embedding.requires_grad_(False)
23 self.position_embedding.requires_grad_(False)
24
25 self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))}
26
27 indices = torch.arange(self.token_embedding.num_embeddings)
28 self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))]
29
30 self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices])
31
32 def forward(
33 self,
34 input_ids: Optional[torch.LongTensor] = None,
35 position_ids: Optional[torch.LongTensor] = None,
36 inputs_embeds: Optional[torch.FloatTensor] = None,
37 ) -> torch.Tensor:
38 seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
39
40 if position_ids is None:
41 position_ids = self.position_ids[:, :seq_length]
42
43 if inputs_embeds is None:
44 mask = torch.isin(
45 input_ids,
46 self.train_indices.to(input_ids.device)
47 ).unsqueeze(-1).expand(-1, -1, self.token_embedding.embedding_dim)
48
49 trainable_input_ids = torch.tensor([
50 [
51 self.id_mapping[id] if id in self.id_mapping else 0
52 for id in batch
53 ]
54 for batch in input_ids
55 ], device=input_ids.device)
56
57 inputs_embeds = torch.where(
58 mask,
59 self.trainable_embedding(trainable_input_ids),
60 self.token_embedding(input_ids)
61 )
62
63 position_embeddings = self.position_embedding(position_ids)
64 embeddings = inputs_embeds + position_embeddings
65
66 return embeddings
67
68 @torch.no_grad()
69 def save(self):
70 self.token_embedding.weight.data[self.train_indices] = self.trainable_embedding.weight.data