Diffstat (limited to 'training/ti.py')
-rw-r--r--  training/ti.py  48
1 file changed, 0 insertions(+), 48 deletions(-)
diff --git a/training/ti.py b/training/ti.py
deleted file mode 100644
index 031fe48..0000000
--- a/training/ti.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from typing import Optional
-
-import torch
-import torch.nn as nn
-
-from transformers.models.clip import CLIPTextModel, CLIPTextConfig
-from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
-
-
-def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
-    text_embeddings = TrainableEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, new_ids)
-    text_encoder.text_model.embeddings = text_embeddings
-
-
-class TrainableEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]):
-        super().__init__(config)
-
-        self.token_embedding = embeddings.token_embedding
-        self.position_embedding = embeddings.position_embedding
-
-        self.train_indices = torch.tensor(new_ids)
-
-        self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
-        self.trainable_embedding.weight.data.zero_()
-        self.trainable_embedding.weight.data[self.train_indices] = self.token_embedding.weight.data[self.train_indices]
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-    ) -> torch.Tensor:
-        device = input_ids.device
-        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
-
-        if position_ids is None:
-            position_ids = self.position_ids[:, :seq_length]
-
-        if inputs_embeds is None:
-            mask = torch.isin(input_ids, self.train_indices.to(device))
-            inputs_embeds = self.token_embedding(input_ids)
-            inputs_embeds[mask] = self.trainable_embedding(input_ids)[mask]
-
-        position_embeddings = self.position_embedding(position_ids)
-        embeddings = inputs_embeds + position_embeddings
-
-        return embeddings
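
For orientation: the removed module wrapped CLIPTextEmbeddings so that token ids listed in new_ids are looked up in a separate nn.Embedding, letting training update only those rows while the original embedding table stays frozen. Below is a minimal, hypothetical usage sketch, not part of this repository; the checkpoint name, the placeholder token string, and the optimizer settings are assumptions for illustration.

# Hypothetical usage sketch, assuming training/ti.py as it existed before this deletion.
import torch
from transformers import CLIPTextModel, CLIPTokenizer

from training.ti import patch_trainable_embeddings  # module removed in this commit

# Placeholder checkpoint and placeholder token string, for illustration only.
pretrained = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(pretrained)
text_encoder = CLIPTextModel.from_pretrained(pretrained)

# Register the new token and grow the embedding matrix before patching.
new_tokens = ["<concept>"]
tokenizer.add_tokens(new_tokens)
text_encoder.resize_token_embeddings(len(tokenizer))
new_ids = tokenizer.convert_tokens_to_ids(new_tokens)

# Swap in the wrapper; only the rows listed in new_ids are copied into the
# separate trainable embedding table.
patch_trainable_embeddings(text_encoder, new_ids)

# Freeze the encoder and optimize only the trainable embedding table.
text_encoder.requires_grad_(False)
trainable = text_encoder.text_model.embeddings.trainable_embedding
trainable.requires_grad_(True)
optimizer = torch.optim.AdamW(trainable.parameters(), lr=1e-4)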