summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-22 21:15:24 +0100
committerVolpeon <git@volpeon.ink>2022-12-22 21:15:24 +0100
commitee9a2777c15d4ceea7ef40802b9a21881f6428a8 (patch)
tree20c8b89d58fdd1ec5fc9b3f1cb7a515d6ad78a79 /training
parentImproved Textual Inversion: Completely exclude untrained embeddings from trai... (diff)
downloadtextual-inversion-diff-ee9a2777c15d4ceea7ef40802b9a21881f6428a8.tar.gz
textual-inversion-diff-ee9a2777c15d4ceea7ef40802b9a21881f6428a8.tar.bz2
textual-inversion-diff-ee9a2777c15d4ceea7ef40802b9a21881f6428a8.zip
Fixed Textual Inversion
Diffstat (limited to 'training')
-rw-r--r--training/ti.py9
1 files changed, 3 insertions, 6 deletions
diff --git a/training/ti.py b/training/ti.py
index a5fd8e4..2efd2f2 100644
--- a/training/ti.py
+++ b/training/ti.py
@@ -19,8 +19,8 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
19 def __init__(self, config: CLIPTextConfig, new_ids: list[int]): 19 def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
20 super().__init__(config) 20 super().__init__(config)
21 21
22 self.token_embedding.requires_grad_(False) 22 self.token_embedding.weight.requires_grad = False
23 self.position_embedding.requires_grad_(False) 23 self.position_embedding.weight.requires_grad = False
24 24
25 self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))} 25 self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))}
26 26
@@ -28,6 +28,7 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
28 self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))] 28 self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))]
29 29
30 self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices]) 30 self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices])
31 self.trainable_embedding.weight.requires_grad = True
31 32
32 def forward( 33 def forward(
33 self, 34 self,
@@ -64,7 +65,3 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
64 embeddings = inputs_embeds + position_embeddings 65 embeddings = inputs_embeds + position_embeddings
65 66
66 return embeddings 67 return embeddings
67
68 @torch.no_grad()
69 def save(self):
70 self.token_embedding.weight.data[self.train_indices] = self.trainable_embedding.weight.data