summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--models/clip/prompt.py6
-rw-r--r--train_ti.py12
-rw-r--r--training/ti.py9
3 files changed, 13 insertions, 14 deletions
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index 9b427a0..da33ecf 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -27,10 +27,10 @@ class PromptProcessor():
27 def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None): 27 def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None):
28 prompts = input_ids.shape[0] 28 prompts = input_ids.shape[0]
29 29
30 input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) 30 input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
31 if attention_mask is not None: 31 if attention_mask is not None:
32 attention_mask = attention_mask.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) 32 attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
33 33
34 text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0] 34 text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0]
35 text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2])) 35 text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
36 return text_embeddings 36 return text_embeddings
diff --git a/train_ti.py b/train_ti.py
index bb51dc2..e933c48 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -365,6 +365,7 @@ class Checkpointer(CheckpointerBase):
365 tokenizer, 365 tokenizer,
366 text_encoder, 366 text_encoder,
367 scheduler, 367 scheduler,
368 text_embeddings,
368 instance_identifier, 369 instance_identifier,
369 placeholder_token, 370 placeholder_token,
370 placeholder_token_id, 371 placeholder_token_id,
@@ -392,6 +393,7 @@ class Checkpointer(CheckpointerBase):
392 self.tokenizer = tokenizer 393 self.tokenizer = tokenizer
393 self.text_encoder = text_encoder 394 self.text_encoder = text_encoder
394 self.scheduler = scheduler 395 self.scheduler = scheduler
396 self.text_embeddings = text_embeddings
395 397
396 @torch.no_grad() 398 @torch.no_grad()
397 def checkpoint(self, step, postfix): 399 def checkpoint(self, step, postfix):
@@ -403,8 +405,10 @@ class Checkpointer(CheckpointerBase):
403 text_encoder = self.accelerator.unwrap_model(self.text_encoder) 405 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
404 406
405 for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): 407 for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
408 training_token_id = self.text_embeddings.id_mapping[placeholder_token_id]
409
406 # Save a checkpoint 410 # Save a checkpoint
407 learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id] 411 learned_embeds = self.text_embeddings.trainable_embedding.weight[training_token_id]
408 learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} 412 learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
409 413
410 filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) 414 filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
@@ -543,7 +547,7 @@ def main():
543 547
544 # Initialize the optimizer 548 # Initialize the optimizer
545 optimizer = optimizer_class( 549 optimizer = optimizer_class(
546 text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings 550 text_embeddings.trainable_embedding.parameters(), # only optimize the embeddings
547 lr=args.learning_rate, 551 lr=args.learning_rate,
548 betas=(args.adam_beta1, args.adam_beta2), 552 betas=(args.adam_beta1, args.adam_beta2),
549 weight_decay=args.adam_weight_decay, 553 weight_decay=args.adam_weight_decay,
@@ -741,6 +745,7 @@ def main():
741 tokenizer=tokenizer, 745 tokenizer=tokenizer,
742 text_encoder=text_encoder, 746 text_encoder=text_encoder,
743 scheduler=checkpoint_scheduler, 747 scheduler=checkpoint_scheduler,
748 text_embeddings=text_embeddings,
744 instance_identifier=args.instance_identifier, 749 instance_identifier=args.instance_identifier,
745 placeholder_token=args.placeholder_token, 750 placeholder_token=args.placeholder_token,
746 placeholder_token_id=placeholder_token_id, 751 placeholder_token_id=placeholder_token_id,
@@ -774,7 +779,6 @@ def main():
774 local_progress_bar.reset() 779 local_progress_bar.reset()
775 780
776 text_encoder.train() 781 text_encoder.train()
777 train_loss = 0.0
778 782
779 for step, batch in enumerate(train_dataloader): 783 for step, batch in enumerate(train_dataloader):
780 with accelerator.accumulate(text_encoder): 784 with accelerator.accumulate(text_encoder):
@@ -834,8 +838,6 @@ def main():
834 lr_scheduler.step() 838 lr_scheduler.step()
835 optimizer.zero_grad(set_to_none=True) 839 optimizer.zero_grad(set_to_none=True)
836 840
837 text_embeddings.save()
838
839 avg_loss.update(loss.detach_(), bsz) 841 avg_loss.update(loss.detach_(), bsz)
840 avg_acc.update(acc.detach_(), bsz) 842 avg_acc.update(acc.detach_(), bsz)
841 843
diff --git a/training/ti.py b/training/ti.py
index a5fd8e4..2efd2f2 100644
--- a/training/ti.py
+++ b/training/ti.py
@@ -19,8 +19,8 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
19 def __init__(self, config: CLIPTextConfig, new_ids: list[int]): 19 def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
20 super().__init__(config) 20 super().__init__(config)
21 21
22 self.token_embedding.requires_grad_(False) 22 self.token_embedding.weight.requires_grad = False
23 self.position_embedding.requires_grad_(False) 23 self.position_embedding.weight.requires_grad = False
24 24
25 self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))} 25 self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))}
26 26
@@ -28,6 +28,7 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
28 self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))] 28 self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))]
29 29
30 self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices]) 30 self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices])
31 self.trainable_embedding.weight.requires_grad = True
31 32
32 def forward( 33 def forward(
33 self, 34 self,
@@ -64,7 +65,3 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
64 embeddings = inputs_embeds + position_embeddings 65 embeddings = inputs_embeds + position_embeddings
65 66
66 return embeddings 67 return embeddings
67
68 @torch.no_grad()
69 def save(self):
70 self.token_embedding.weight.data[self.train_indices] = self.trainable_embedding.weight.data