Diffstat (limited to 'training')
-rw-r--r--  training/ti.py | 15
1 file changed, 5 insertions(+), 10 deletions(-)
diff --git a/training/ti.py b/training/ti.py
index 8b2fdd6..dc33e5e 100644
--- a/training/ti.py
+++ b/training/ti.py
@@ -8,26 +8,21 @@ from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
 
 def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
-    text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids)
-
-    text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding
-    text_embeddings.token_embedding.weight.requires_grad = False
-
-    text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding
-    text_embeddings.position_embedding.weight.requires_grad = False
-
+    text_embeddings = TrainableEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, new_ids)
     text_encoder.text_model.embeddings = text_embeddings
 
 
 class TrainableEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]):
         super().__init__(config)
 
         self.train_indices = torch.tensor(new_ids)
 
         self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
+
+        self.token_embedding = embeddings.token_embedding
+        self.position_embedding = embeddings.position_embedding
         self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone()
-        self.trainable_embedding.weight.requires_grad = True
 
     def forward(
         self,
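
For context, a minimal usage sketch of how the patched encoder might be set up with the Hugging Face transformers API. This is not part of the commit; the module path training.ti, the checkpoint name, and the placeholder token are assumptions for illustration.

# Minimal usage sketch (assumption, not part of this commit): add a new token,
# resize the input embedding matrix, then patch the text encoder so the new
# ids can be trained through TrainableEmbeddings.
from transformers import CLIPTextModel, CLIPTokenizer

from training.ti import patch_trainable_embeddings  # assumed module path

model_name = "openai/clip-vit-large-patch14"  # illustrative checkpoint
text_encoder = CLIPTextModel.from_pretrained(model_name)
tokenizer = CLIPTokenizer.from_pretrained(model_name)

# Hypothetical placeholder token for the concept being learned.
tokenizer.add_tokens(["<new-concept>"])
text_encoder.resize_token_embeddings(len(tokenizer))
new_ids = tokenizer.convert_tokens_to_ids(["<new-concept>"])

# Replaces text_encoder.text_model.embeddings with a TrainableEmbeddings
# instance that reuses the original token/position embeddings and keeps a
# cloned trainable_embedding for the new ids.
patch_trainable_embeddings(text_encoder, new_ids)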