From d2105d96fdd18da035d2ad412e3fb6f579d5571a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Dec 2022 14:30:15 +0100 Subject: Fixed Textual Inversion --- training/ti.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) (limited to 'training') diff --git a/training/ti.py b/training/ti.py index 8b2fdd6..dc33e5e 100644 --- a/training/ti.py +++ b/training/ti.py @@ -8,26 +8,21 @@ from transformers.models.clip.modeling_clip import CLIPTextEmbeddings def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]): - text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids) - - text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding - text_embeddings.token_embedding.weight.requires_grad = False - - text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding - text_embeddings.position_embedding.weight.requires_grad = False - + text_embeddings = TrainableEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, new_ids) text_encoder.text_model.embeddings = text_embeddings class TrainableEmbeddings(CLIPTextEmbeddings): - def __init__(self, config: CLIPTextConfig, new_ids: list[int]): + def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]): super().__init__(config) self.train_indices = torch.tensor(new_ids) self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) + + self.token_embedding = embeddings.token_embedding + self.position_embedding = embeddings.position_embedding self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone() - self.trainable_embedding.weight.requires_grad = True def forward( self, -- cgit v1.2.3-54-g00ecf