 training/ti.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)
diff --git a/training/ti.py b/training/ti.py
index 2efd2f2..2e5139a 100644
--- a/training/ti.py
+++ b/training/ti.py
@@ -9,9 +9,15 @@ from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
 def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
     text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids)
-    text_embeddings.token_embedding.weight = text_encoder.text_model.embeddings.token_embedding.weight
-    text_embeddings.position_embedding.weight = text_encoder.text_model.embeddings.position_embedding.weight
+
+    text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding
+    text_embeddings.token_embedding.weight.requires_grad = False
+
+    text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding
+    text_embeddings.position_embedding.weight.requires_grad = False
+
     text_encoder.text_model.embeddings = text_embeddings
+
     return text_embeddings
 
 
@@ -19,15 +25,12 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
     def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
         super().__init__(config)
 
-        self.token_embedding.weight.requires_grad = False
-        self.position_embedding.weight.requires_grad = False
-
         self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))}
 
-        indices = torch.arange(self.token_embedding.num_embeddings)
-        self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))]
+        self.train_indices = torch.tensor(new_ids)
 
-        self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices])
+        self.trainable_embedding = nn.Embedding(len(new_ids), self.token_embedding.embedding_dim)
+        self.trainable_embedding.weight.data = self.token_embedding.weight.data[self.train_indices]
         self.trainable_embedding.weight.requires_grad = True
 
     def forward(
@@ -42,10 +45,7 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
         position_ids = self.position_ids[:, :seq_length]
 
         if inputs_embeds is None:
-            mask = torch.isin(
-                input_ids,
-                self.train_indices.to(input_ids.device)
-            ).unsqueeze(-1).expand(-1, -1, self.token_embedding.embedding_dim)
+            mask = torch.isin(input_ids, self.train_indices.to(input_ids.device))[:, :, None]
 
             trainable_input_ids = torch.tensor([
                 [
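Usage sketch (not part of this patch): only patch_trainable_embeddings and the trainable_embedding attribute come from training/ti.py; the module import path, the placeholder token, the checkpoint name, and the optimizer settings below are assumptions for illustration.

# Minimal sketch: train only the embedding rows for newly added tokens.
import torch
from transformers import CLIPTextModel, CLIPTokenizer

from training.ti import patch_trainable_embeddings  # assumed import path for this file

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Register a placeholder token and grow the base embedding table so its id exists.
tokenizer.add_tokens(["<concept>"])
text_encoder.resize_token_embeddings(len(tokenizer))
new_ids = tokenizer.convert_tokens_to_ids(["<concept>"])

# Swap in the patched embeddings: the frozen token/position tables are shared with
# the encoder, and only the copied rows in trainable_embedding receive gradients.
text_embeddings = patch_trainable_embeddings(text_encoder, new_ids)
optimizer = torch.optim.AdamW(text_embeddings.trainable_embedding.parameters(), lr=1e-3)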