summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/ti.py24
1 files changed, 12 insertions, 12 deletions
diff --git a/training/ti.py b/training/ti.py
index 2efd2f2..2e5139a 100644
--- a/training/ti.py
+++ b/training/ti.py
@@ -9,9 +9,15 @@ from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
9 9
10def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]): 10def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
11 text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids) 11 text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids)
12 text_embeddings.token_embedding.weight = text_encoder.text_model.embeddings.token_embedding.weight 12
13 text_embeddings.position_embedding.weight = text_encoder.text_model.embeddings.position_embedding.weight 13 text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding
14 text_embeddings.token_embedding.weight.requires_grad = False
15
16 text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding
17 text_embeddings.position_embedding.weight.requires_grad = False
18
14 text_encoder.text_model.embeddings = text_embeddings 19 text_encoder.text_model.embeddings = text_embeddings
20
15 return text_embeddings 21 return text_embeddings
16 22
17 23
@@ -19,15 +25,12 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
19 def __init__(self, config: CLIPTextConfig, new_ids: list[int]): 25 def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
20 super().__init__(config) 26 super().__init__(config)
21 27
22 self.token_embedding.weight.requires_grad = False
23 self.position_embedding.weight.requires_grad = False
24
25 self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))} 28 self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))}
26 29
27 indices = torch.arange(self.token_embedding.num_embeddings) 30 self.train_indices = torch.tensor(new_ids)
28 self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))]
29 31
30 self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices]) 32 self.trainable_embedding = nn.Embedding(len(new_ids), self.token_embedding.embedding_dim)
33 self.trainable_embedding.weight.data = self.token_embedding.weight.data[self.train_indices]
31 self.trainable_embedding.weight.requires_grad = True 34 self.trainable_embedding.weight.requires_grad = True
32 35
33 def forward( 36 def forward(
@@ -42,10 +45,7 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
42 position_ids = self.position_ids[:, :seq_length] 45 position_ids = self.position_ids[:, :seq_length]
43 46
44 if inputs_embeds is None: 47 if inputs_embeds is None:
45 mask = torch.isin( 48 mask = torch.isin(input_ids, self.train_indices.to(input_ids.device))[:, :, None]
46 input_ids,
47 self.train_indices.to(input_ids.device)
48 ).unsqueeze(-1).expand(-1, -1, self.token_embedding.embedding_dim)
49 49
50 trainable_input_ids = torch.tensor([ 50 trainable_input_ids = torch.tensor([
51 [ 51 [