From 92e5cd4563a62413e72370884c50fb1ab2a91854 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Dec 2022 15:16:19 +0100 Subject: Update --- training/ti.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'training') diff --git a/training/ti.py b/training/ti.py index dc33e5e..1318e22 100644 --- a/training/ti.py +++ b/training/ti.py @@ -16,12 +16,12 @@ class TrainableEmbeddings(CLIPTextEmbeddings): def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]): super().__init__(config) + self.token_embedding = embeddings.token_embedding + self.position_embedding = embeddings.position_embedding + self.train_indices = torch.tensor(new_ids) self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) - - self.token_embedding = embeddings.token_embedding - self.position_embedding = embeddings.position_embedding self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone() def forward( -- cgit v1.2.3-70-g09d2