diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/ti.py | 9 |
1 files changed, 3 insertions, 6 deletions
diff --git a/training/ti.py b/training/ti.py index a5fd8e4..2efd2f2 100644 --- a/training/ti.py +++ b/training/ti.py | |||
| @@ -19,8 +19,8 @@ class TrainableEmbeddings(CLIPTextEmbeddings): | |||
| 19 | def __init__(self, config: CLIPTextConfig, new_ids: list[int]): | 19 | def __init__(self, config: CLIPTextConfig, new_ids: list[int]): |
| 20 | super().__init__(config) | 20 | super().__init__(config) |
| 21 | 21 | ||
| 22 | self.token_embedding.requires_grad_(False) | 22 | self.token_embedding.weight.requires_grad = False |
| 23 | self.position_embedding.requires_grad_(False) | 23 | self.position_embedding.weight.requires_grad = False |
| 24 | 24 | ||
| 25 | self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))} | 25 | self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))} |
| 26 | 26 | ||
| @@ -28,6 +28,7 @@ class TrainableEmbeddings(CLIPTextEmbeddings): | |||
| 28 | self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))] | 28 | self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))] |
| 29 | 29 | ||
| 30 | self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices]) | 30 | self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices]) |
| 31 | self.trainable_embedding.weight.requires_grad = True | ||
| 31 | 32 | ||
| 32 | def forward( | 33 | def forward( |
| 33 | self, | 34 | self, |
| @@ -64,7 +65,3 @@ class TrainableEmbeddings(CLIPTextEmbeddings): | |||
| 64 | embeddings = inputs_embeds + position_embeddings | 65 | embeddings = inputs_embeds + position_embeddings |
| 65 | 66 | ||
| 66 | return embeddings | 67 | return embeddings |
| 67 | |||
| 68 | @torch.no_grad() | ||
| 69 | def save(self): | ||
| 70 | self.token_embedding.weight.data[self.train_indices] = self.trainable_embedding.weight.data | ||
