diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-22 21:15:24 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-22 21:15:24 +0100 |
| commit | ee9a2777c15d4ceea7ef40802b9a21881f6428a8 (patch) | |
| tree | 20c8b89d58fdd1ec5fc9b3f1cb7a515d6ad78a79 /training | |
| parent | Improved Textual Inversion: Completely exclude untrained embeddings from trai... (diff) | |
| download | textual-inversion-diff-ee9a2777c15d4ceea7ef40802b9a21881f6428a8.tar.gz textual-inversion-diff-ee9a2777c15d4ceea7ef40802b9a21881f6428a8.tar.bz2 textual-inversion-diff-ee9a2777c15d4ceea7ef40802b9a21881f6428a8.zip | |
Fixed Textual Inversion
Diffstat (limited to 'training')
| -rw-r--r-- | training/ti.py | 9 |
1 files changed, 3 insertions, 6 deletions
diff --git a/training/ti.py b/training/ti.py index a5fd8e4..2efd2f2 100644 --- a/training/ti.py +++ b/training/ti.py | |||
| @@ -19,8 +19,8 @@ class TrainableEmbeddings(CLIPTextEmbeddings): | |||
| 19 | def __init__(self, config: CLIPTextConfig, new_ids: list[int]): | 19 | def __init__(self, config: CLIPTextConfig, new_ids: list[int]): |
| 20 | super().__init__(config) | 20 | super().__init__(config) |
| 21 | 21 | ||
| 22 | self.token_embedding.requires_grad_(False) | 22 | self.token_embedding.weight.requires_grad = False |
| 23 | self.position_embedding.requires_grad_(False) | 23 | self.position_embedding.weight.requires_grad = False |
| 24 | 24 | ||
| 25 | self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))} | 25 | self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))} |
| 26 | 26 | ||
| @@ -28,6 +28,7 @@ class TrainableEmbeddings(CLIPTextEmbeddings): | |||
| 28 | self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))] | 28 | self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))] |
| 29 | 29 | ||
| 30 | self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices]) | 30 | self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices]) |
| 31 | self.trainable_embedding.weight.requires_grad = True | ||
| 31 | 32 | ||
| 32 | def forward( | 33 | def forward( |
| 33 | self, | 34 | self, |
| @@ -64,7 +65,3 @@ class TrainableEmbeddings(CLIPTextEmbeddings): | |||
| 64 | embeddings = inputs_embeds + position_embeddings | 65 | embeddings = inputs_embeds + position_embeddings |
| 65 | 66 | ||
| 66 | return embeddings | 67 | return embeddings |
| 67 | |||
| 68 | @torch.no_grad() | ||
| 69 | def save(self): | ||
| 70 | self.token_embedding.weight.data[self.train_indices] = self.trainable_embedding.weight.data | ||
