 models/clip/prompt.py |  6 +++---
 train_ti.py           | 12 +++++++-----
 training/ti.py        |  9 +++------
 3 files changed, 13 insertions(+), 14 deletions(-)
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index 9b427a0..da33ecf 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -27,10 +27,10 @@ class PromptProcessor():
     def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None):
         prompts = input_ids.shape[0]
 
-        input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
+        input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
         if attention_mask is not None:
-            attention_mask = attention_mask.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
+            attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
 
         text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0]
-        text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2]))
+        text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
         return text_embeddings
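
A note on the change above: reshape() silently copies when the source tensor is non-contiguous, while view() always aliases the existing storage and raises if it cannot. A minimal standalone sketch of the chunked-prompt round trip, with hypothetical shapes (77-token chunks, 768-dim hidden states) standing in for the real tokenizer and text encoder:

import torch

# Hypothetical shapes standing in for tokenizer.model_max_length and the
# CLIP hidden size; the real values come from the loaded model.
model_max_length = 77
prompts, chunks, dim = 4, 3, 768

input_ids = torch.randint(0, 49408, (prompts, chunks * model_max_length))

# Flatten every 77-token chunk into its own row for the text encoder...
flat = input_ids.view((-1, model_max_length))             # (12, 77)

# ...encode (a stand-in tensor here), then regroup the output per prompt.
hidden = torch.randn(flat.shape[0], model_max_length, dim)
per_prompt = hidden.view((prompts, -1, hidden.shape[2]))  # (4, 231, 768)

# view() shares storage with its source and raises on non-contiguous input;
# reshape() would have silently copied in that case.
assert per_prompt.data_ptr() == hidden.data_ptr()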
diff --git a/train_ti.py b/train_ti.py
index bb51dc2..e933c48 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -365,6 +365,7 @@ class Checkpointer(CheckpointerBase):
         tokenizer,
         text_encoder,
         scheduler,
+        text_embeddings,
         instance_identifier,
         placeholder_token,
         placeholder_token_id,
@@ -392,6 +393,7 @@ class Checkpointer(CheckpointerBase):
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.scheduler = scheduler
+        self.text_embeddings = text_embeddings
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
@@ -403,8 +405,10 @@ class Checkpointer(CheckpointerBase):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
+            training_token_id = self.text_embeddings.id_mapping[placeholder_token_id]
+
             # Save a checkpoint
-            learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds = self.text_embeddings.trainable_embedding.weight[training_token_id]
             learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
             filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
@@ -543,7 +547,7 @@ def main():
 
     # Initialize the optimizer
     optimizer = optimizer_class(
-        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
+        text_embeddings.trainable_embedding.parameters(),  # only optimize the embeddings
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
@@ -741,6 +745,7 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
+        text_embeddings=text_embeddings,
         instance_identifier=args.instance_identifier,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
@@ -774,7 +779,6 @@ def main():
         local_progress_bar.reset()
 
         text_encoder.train()
-        train_loss = 0.0
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
@@ -834,8 +838,6 @@ def main():
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
-                text_embeddings.save()
-
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
 
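
The two functional changes in this file both route through the new TrainableEmbeddings module: the optimizer now receives only trainable_embedding.parameters(), so the frozen vocabulary table never enters the update step, and the checkpointer looks learned vectors up via id_mapping instead of indexing the full token_embedding. A sketch of that lookup path, assuming text_embeddings is the TrainableEmbeddings instance from training/ti.py (save_learned_embeds and checkpoints_path are hypothetical names for illustration):

import torch
from slugify import slugify

# Hypothetical helper mirroring the checkpoint loop above; checkpoints_path
# is assumed to be a pathlib.Path to the output directory.
def save_learned_embeds(text_embeddings, placeholder_token, placeholder_token_id,
                        step, postfix, checkpoints_path):
    # id_mapping translates a tokenizer id into a row of the small trainable
    # table (one row per new token), not into the full vocabulary.
    training_token_id = text_embeddings.id_mapping[placeholder_token_id]
    learned_embeds = text_embeddings.trainable_embedding.weight[training_token_id]

    learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
    filename = "%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
    torch.save(learned_embeds_dict, checkpoints_path / filename)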
diff --git a/training/ti.py b/training/ti.py
index a5fd8e4..2efd2f2 100644
--- a/training/ti.py
+++ b/training/ti.py
@@ -19,8 +19,8 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
     def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
         super().__init__(config)
 
-        self.token_embedding.requires_grad_(False)
-        self.position_embedding.requires_grad_(False)
+        self.token_embedding.weight.requires_grad = False
+        self.position_embedding.weight.requires_grad = False
 
         self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))}
 
@@ -28,6 +28,7 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
         self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))]
 
         self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices])
+        self.trainable_embedding.weight.requires_grad = True
 
     def forward(
         self,
@@ -64,7 +65,3 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
         embeddings = inputs_embeds + position_embeddings
 
         return embeddings
-
-    @torch.no_grad()
-    def save(self):
-        self.token_embedding.weight.data[self.train_indices] = self.trainable_embedding.weight.data
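
Two details in this file are easy to miss. First, nn.Embedding.from_pretrained() defaults to freeze=True, so the copied rows start out without gradients; the added requires_grad = True line is what actually makes them trainable. Second, with the optimizer pointed at trainable_embedding directly and forward() (mostly elided above) presumably serving the new ids from that table, the per-step save() copy-back into the frozen token_embedding is no longer needed during training, which is why the method is removed. A standalone sketch of the freeze gotcha, using a toy 10-token vocabulary in place of the CLIP table:

import torch
import torch.nn as nn

# Toy stand-ins for the CLIP token table and two newly added token ids.
vocab = nn.Embedding(10, 4)
vocab.weight.requires_grad = False        # base table frozen, as in __init__
new_ids = torch.tensor([7, 8])

trainable = nn.Embedding.from_pretrained(vocab.weight[new_ids])

# from_pretrained() defaults to freeze=True, so without the explicit override
# these copied rows would never receive gradients.
assert trainable.weight.requires_grad is False
trainable.weight.requires_grad = True

If a downstream consumer expects a standard token_embedding, a one-time copy-back at export time would take over the role of the removed per-step save().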
