summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--training/ti.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/training/ti.py b/training/ti.py
index dc33e5e..1318e22 100644
--- a/training/ti.py
+++ b/training/ti.py
@@ -16,12 +16,12 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
16 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]): 16 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]):
17 super().__init__(config) 17 super().__init__(config)
18 18
19 self.token_embedding = embeddings.token_embedding
20 self.position_embedding = embeddings.position_embedding
21
19 self.train_indices = torch.tensor(new_ids) 22 self.train_indices = torch.tensor(new_ids)
20 23
21 self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) 24 self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
22
23 self.token_embedding = embeddings.token_embedding
24 self.position_embedding = embeddings.position_embedding
25 self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone() 25 self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone()
26 26
27 def forward( 27 def forward(