Diffstat (limited to 'training')
-rw-r--r-- | training/ti.py | 15
1 file changed, 5 insertions, 10 deletions
diff --git a/training/ti.py b/training/ti.py
index 8b2fdd6..dc33e5e 100644
--- a/training/ti.py
+++ b/training/ti.py
@@ -8,26 +8,21 @@ from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
 
 def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
-    text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids)
-
-    text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding
-    text_embeddings.token_embedding.weight.requires_grad = False
-
-    text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding
-    text_embeddings.position_embedding.weight.requires_grad = False
-
+    text_embeddings = TrainableEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, new_ids)
     text_encoder.text_model.embeddings = text_embeddings
 
 
 class TrainableEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]):
         super().__init__(config)
 
         self.train_indices = torch.tensor(new_ids)
 
         self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
+
+        self.token_embedding = embeddings.token_embedding
+        self.position_embedding = embeddings.position_embedding
         self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone()
-        self.trainable_embedding.weight.requires_grad = True
 
     def forward(
         self,
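
For context, a hypothetical usage sketch, not part of this commit: it assumes the usual textual-inversion setup in which placeholder tokens are first added to the tokenizer and the token embedding matrix is resized, after which patch_trainable_embeddings swaps the text encoder's embeddings for the TrainableEmbeddings wrapper. The token name and checkpoint id are illustrative only.

    # Hypothetical usage sketch; the placeholder token and checkpoint are illustrative.
    from transformers import CLIPTextModel, CLIPTokenizer

    from training.ti import patch_trainable_embeddings

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # Register a placeholder token and grow the embedding matrix to match.
    placeholder_tokens = ["<concept>"]
    tokenizer.add_tokens(placeholder_tokens)
    text_encoder.resize_token_embeddings(len(tokenizer))
    new_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)

    # Swap in TrainableEmbeddings so the new ids can be routed through the
    # trainable copy during training (selection presumably happens in
    # forward via train_indices, which is not shown in this hunk).
    patch_trainable_embeddings(text_encoder, new_ids)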