summaryrefslogtreecommitdiffstats
path: root/training/ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-25 14:59:00 +0100
committerVolpeon <git@volpeon.ink>2022-12-25 14:59:00 +0100
commit1af6c15f795b5ba4df9179d8c59c6b595040a33f (patch)
treefa7c033a6c259b64fa84b5483894150b07c9337f /training/ti.py
parentUpdate (diff)
downloadtextual-inversion-diff-1af6c15f795b5ba4df9179d8c59c6b595040a33f.tar.gz
textual-inversion-diff-1af6c15f795b5ba4df9179d8c59c6b595040a33f.tar.bz2
textual-inversion-diff-1af6c15f795b5ba4df9179d8c59c6b595040a33f.zip
Update
Diffstat (limited to 'training/ti.py')
-rw-r--r--training/ti.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/training/ti.py b/training/ti.py
index 1318e22..031fe48 100644
--- a/training/ti.py
+++ b/training/ti.py
@@ -22,7 +22,8 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
22 self.train_indices = torch.tensor(new_ids) 22 self.train_indices = torch.tensor(new_ids)
23 23
24 self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) 24 self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
25 self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone() 25 self.trainable_embedding.weight.data.zero_()
26 self.trainable_embedding.weight.data[self.train_indices] = self.token_embedding.weight.data[self.train_indices]
26 27
27 def forward( 28 def forward(
28 self, 29 self,