From ee9a2777c15d4ceea7ef40802b9a21881f6428a8 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 22 Dec 2022 21:15:24 +0100 Subject: Fixed Textual Inversion --- training/ti.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) (limited to 'training') diff --git a/training/ti.py b/training/ti.py index a5fd8e4..2efd2f2 100644 --- a/training/ti.py +++ b/training/ti.py @@ -19,8 +19,8 @@ class TrainableEmbeddings(CLIPTextEmbeddings): def __init__(self, config: CLIPTextConfig, new_ids: list[int]): super().__init__(config) - self.token_embedding.requires_grad_(False) - self.position_embedding.requires_grad_(False) + self.token_embedding.weight.requires_grad = False + self.position_embedding.weight.requires_grad = False self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))} @@ -28,6 +28,7 @@ class TrainableEmbeddings(CLIPTextEmbeddings): self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))] self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices]) + self.trainable_embedding.weight.requires_grad = True def forward( self, @@ -64,7 +65,3 @@ class TrainableEmbeddings(CLIPTextEmbeddings): embeddings = inputs_embeds + position_embeddings return embeddings - - @torch.no_grad() - def save(self): - self.token_embedding.weight.data[self.train_indices] = self.trainable_embedding.weight.data -- cgit v1.2.3-70-g09d2