diff options
author | Volpeon <git@volpeon.ink> | 2022-12-25 14:59:00 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-25 14:59:00 +0100 |
commit | 1af6c15f795b5ba4df9179d8c59c6b595040a33f (patch) | |
tree | fa7c033a6c259b64fa84b5483894150b07c9337f /training/ti.py | |
parent | Update (diff) | |
download | textual-inversion-diff-1af6c15f795b5ba4df9179d8c59c6b595040a33f.tar.gz textual-inversion-diff-1af6c15f795b5ba4df9179d8c59c6b595040a33f.tar.bz2 textual-inversion-diff-1af6c15f795b5ba4df9179d8c59c6b595040a33f.zip |
Update
Diffstat (limited to 'training/ti.py')
-rw-r--r-- | training/ti.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/training/ti.py b/training/ti.py index 1318e22..031fe48 100644 --- a/training/ti.py +++ b/training/ti.py | |||
@@ -22,7 +22,8 @@ class TrainableEmbeddings(CLIPTextEmbeddings): | |||
22 | self.train_indices = torch.tensor(new_ids) | 22 | self.train_indices = torch.tensor(new_ids) |
23 | 23 | ||
24 | self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) | 24 | self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) |
25 | self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone() | 25 | self.trainable_embedding.weight.data.zero_() |
26 | self.trainable_embedding.weight.data[self.train_indices] = self.token_embedding.weight.data[self.train_indices] | ||
26 | 27 | ||
27 | def forward( | 28 | def forward( |
28 | self, | 29 | self, |