summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-24 15:16:19 +0100
committerVolpeon <git@volpeon.ink>2022-12-24 15:16:19 +0100
commit92e5cd4563a62413e72370884c50fb1ab2a91854 (patch)
tree4a0d12c5cddb266a6881f97aa93f9065e29d1ee4
parentFixed Textual Inversion (diff)
downloadtextual-inversion-diff-92e5cd4563a62413e72370884c50fb1ab2a91854.tar.gz
textual-inversion-diff-92e5cd4563a62413e72370884c50fb1ab2a91854.tar.bz2
textual-inversion-diff-92e5cd4563a62413e72370884c50fb1ab2a91854.zip
Update
-rw-r--r--training/ti.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/training/ti.py b/training/ti.py
index dc33e5e..1318e22 100644
--- a/training/ti.py
+++ b/training/ti.py
@@ -16,12 +16,12 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
16 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]): 16 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]):
17 super().__init__(config) 17 super().__init__(config)
18 18
19 self.token_embedding = embeddings.token_embedding
20 self.position_embedding = embeddings.position_embedding
21
19 self.train_indices = torch.tensor(new_ids) 22 self.train_indices = torch.tensor(new_ids)
20 23
21 self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) 24 self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
22
23 self.token_embedding = embeddings.token_embedding
24 self.position_embedding = embeddings.position_embedding
25 self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone() 25 self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone()
26 26
27 def forward( 27 def forward(