| author | Volpeon <git@volpeon.ink> | 2022-12-22 21:15:24 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-22 21:15:24 +0100 |
| commit | ee9a2777c15d4ceea7ef40802b9a21881f6428a8 | |
| tree | 20c8b89d58fdd1ec5fc9b3f1cb7a515d6ad78a79 /training | |
| parent | Improved Textual Inversion: Completely exclude untrained embeddings from trai... | |
Fixed Textual Inversion
Diffstat (limited to 'training')
-rw-r--r-- | training/ti.py | 9 |
1 file changed, 3 insertions, 6 deletions
diff --git a/training/ti.py b/training/ti.py
index a5fd8e4..2efd2f2 100644
--- a/training/ti.py
+++ b/training/ti.py
@@ -19,8 +19,8 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
     def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
         super().__init__(config)
 
-        self.token_embedding.requires_grad_(False)
-        self.position_embedding.requires_grad_(False)
+        self.token_embedding.weight.requires_grad = False
+        self.position_embedding.weight.requires_grad = False
 
         self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))}
 
@@ -28,6 +28,7 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
         self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))]
 
         self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices])
+        self.trainable_embedding.weight.requires_grad = True
 
     def forward(
         self,
@@ -64,7 +65,3 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
         embeddings = inputs_embeds + position_embeddings
 
         return embeddings
-
-    @torch.no_grad()
-    def save(self):
-        self.token_embedding.weight.data[self.train_indices] = self.trainable_embedding.weight.data
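For context on the one-line addition: `torch.nn.Embedding.from_pretrained` freezes the copied weights by default (`freeze=True`), so the trainable copy built from the existing token embedding rows would never receive gradients unless `requires_grad` is turned back on, which is what this commit does. A minimal sketch of that behaviour, using a made-up vocabulary size, embedding width, and token ids rather than anything taken from this repository:

```python
import torch
import torch.nn as nn

# Hypothetical sizes and ids, for illustration only.
token_embedding = nn.Embedding(10, 4)            # stands in for the CLIP token embedding
token_embedding.weight.requires_grad = False     # frozen, as in the patched __init__
train_indices = torch.tensor([7, 9])             # ids of the newly added tokens

# from_pretrained defaults to freeze=True, so the copy starts out frozen as well ...
trainable_embedding = nn.Embedding.from_pretrained(token_embedding.weight[train_indices])
print(trainable_embedding.weight.requires_grad)  # False -> nothing would train

# ... which is what the added line corrects: re-enable gradients on the copy only.
trainable_embedding.weight.requires_grad = True

loss = trainable_embedding(torch.tensor([0, 1])).sum()
loss.backward()
print(trainable_embedding.weight.grad is not None)  # True: the new embeddings now learn
print(token_embedding.weight.requires_grad)         # False: the original stays frozen
```

Without the added line, the copied rows keep `requires_grad=False` and the new token embeddings never move from their initial values, which matches the "Fixed Textual Inversion" subject of this commit.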