From 1af6c15f795b5ba4df9179d8c59c6b595040a33f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 25 Dec 2022 14:59:00 +0100 Subject: Update --- training/ti.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'training/ti.py') diff --git a/training/ti.py b/training/ti.py index 1318e22..031fe48 100644 --- a/training/ti.py +++ b/training/ti.py @@ -22,7 +22,8 @@ class TrainableEmbeddings(CLIPTextEmbeddings): self.train_indices = torch.tensor(new_ids) self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) - self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone() + self.trainable_embedding.weight.data.zero_() + self.trainable_embedding.weight.data[self.train_indices] = self.token_embedding.weight.data[self.train_indices] def forward( self, -- cgit v1.2.3-54-g00ecf