diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-24 15:16:19 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-24 15:16:19 +0100 |
| commit | 92e5cd4563a62413e72370884c50fb1ab2a91854 (patch) | |
| tree | 4a0d12c5cddb266a6881f97aa93f9065e29d1ee4 | |
| parent | Fixed Textual Inversion (diff) | |
| download | textual-inversion-diff-92e5cd4563a62413e72370884c50fb1ab2a91854.tar.gz textual-inversion-diff-92e5cd4563a62413e72370884c50fb1ab2a91854.tar.bz2 textual-inversion-diff-92e5cd4563a62413e72370884c50fb1ab2a91854.zip | |
Update
| -rw-r--r-- | training/ti.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/training/ti.py b/training/ti.py index dc33e5e..1318e22 100644 --- a/training/ti.py +++ b/training/ti.py | |||
| @@ -16,12 +16,12 @@ class TrainableEmbeddings(CLIPTextEmbeddings): | |||
| 16 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]): | 16 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]): |
| 17 | super().__init__(config) | 17 | super().__init__(config) |
| 18 | 18 | ||
| 19 | self.token_embedding = embeddings.token_embedding | ||
| 20 | self.position_embedding = embeddings.position_embedding | ||
| 21 | |||
| 19 | self.train_indices = torch.tensor(new_ids) | 22 | self.train_indices = torch.tensor(new_ids) |
| 20 | 23 | ||
| 21 | self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) | 24 | self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) |
| 22 | |||
| 23 | self.token_embedding = embeddings.token_embedding | ||
| 24 | self.position_embedding = embeddings.position_embedding | ||
| 25 | self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone() | 25 | self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone() |
| 26 | 26 | ||
| 27 | def forward( | 27 | def forward( |
