| author | Volpeon <git@volpeon.ink> | 2022-12-22 21:15:24 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-22 21:15:24 +0100 |
| commit | ee9a2777c15d4ceea7ef40802b9a21881f6428a8 (patch) | |
| tree | 20c8b89d58fdd1ec5fc9b3f1cb7a515d6ad78a79 | |
| parent | Improved Textual Inversion: Completely exclude untrained embeddings from trai... (diff) | |
Fixed Textual Inversion
-rw-r--r-- | models/clip/prompt.py | 6
-rw-r--r-- | train_ti.py | 12
-rw-r--r-- | training/ti.py | 9

3 files changed, 13 insertions, 14 deletions
```diff
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index 9b427a0..da33ecf 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -27,10 +27,10 @@ class PromptProcessor():
     def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None):
         prompts = input_ids.shape[0]
 
-        input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
+        input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
         if attention_mask is not None:
-            attention_mask = attention_mask.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
+            attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
 
         text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0]
-        text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2]))
+        text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
         return text_embeddings
```
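The `reshape` → `view` swap is a tightening, not just a rename: `view` never copies, it aliases the existing storage and raises if the tensor's memory layout makes that impossible, whereas `reshape` silently falls back to a copy. Since the chunked `input_ids` here are freshly stacked and contiguous, `view` is safe and guarantees the encoder reads the same memory. A minimal, repository-independent sketch of the difference:

```python
import torch

# view() always aliases the original storage; reshape() may silently copy.
x = torch.arange(6).reshape(2, 3)
v = x.view(3, 2)       # fine: x is contiguous, v shares its memory
v[0, 0] = 99
assert x[0, 0] == 99   # the write is visible through x: no copy was made

t = x.t()              # transposing makes the layout non-contiguous
r = t.reshape(6)       # reshape() quietly returns a copy here
try:
    t.view(6)          # view() raises instead of copying
except RuntimeError as err:
    print("view refused:", err)
```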
```diff
diff --git a/train_ti.py b/train_ti.py
index bb51dc2..e933c48 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -365,6 +365,7 @@ class Checkpointer(CheckpointerBase):
         tokenizer,
         text_encoder,
         scheduler,
+        text_embeddings,
         instance_identifier,
         placeholder_token,
         placeholder_token_id,
@@ -392,6 +393,7 @@ class Checkpointer(CheckpointerBase):
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.scheduler = scheduler
+        self.text_embeddings = text_embeddings
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
@@ -403,8 +405,10 @@ class Checkpointer(CheckpointerBase):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
+            training_token_id = self.text_embeddings.id_mapping[placeholder_token_id]
+
             # Save a checkpoint
-            learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds = self.text_embeddings.trainable_embedding.weight[training_token_id]
             learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
             filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
```
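The checkpoint path changes because the learned vectors no longer live in the text encoder's (frozen) input embedding matrix but in the separate `trainable_embedding` table, whose rows are indexed by position rather than by tokenizer id; `id_mapping` translates between the two. A sketch of that lookup, where the names mirror the hunk above but the token ids, embedding width, and filename are made up for illustration:

```python
import torch
import torch.nn as nn

# Hypothetical ids assigned to two placeholder tokens added to the tokenizer.
new_ids = [49408, 49409]
id_mapping = {tok_id: row for row, tok_id in enumerate(new_ids)}  # {49408: 0, 49409: 1}

# Small trainable table holding only the placeholder vectors (768-dim here).
trainable_embedding = nn.Embedding(len(new_ids), 768)

# Saving one placeholder: map its tokenizer id to its row, then detach to CPU.
placeholder_token, placeholder_token_id = "<my-token>", 49409
learned_embeds = trainable_embedding.weight[id_mapping[placeholder_token_id]]
learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, "my-token_100_test.bin")
```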
```diff
@@ -543,7 +547,7 @@ def main():
 
     # Initialize the optimizer
     optimizer = optimizer_class(
-        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
+        text_embeddings.trainable_embedding.parameters(),  # only optimize the embeddings
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
```
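Handing the optimizer only `trainable_embedding.parameters()` makes the freeze structural: the optimizer holds no reference to the full vocabulary matrix, so it cannot update it regardless of what gradients exist. A self-contained sketch of the pattern, with stand-in sizes:

```python
import torch
import torch.nn as nn

vocab = nn.Embedding(49408, 768)      # stand-in for CLIP's full vocabulary table
vocab.weight.requires_grad = False    # frozen, mirroring TrainableEmbeddings
trainable = nn.Embedding(2, 768)      # stand-in for the placeholder rows

# Only the small table is registered with the optimizer, so only it can move.
optimizer = torch.optim.AdamW(trainable.parameters(), lr=1e-2)

before = trainable.weight.detach().clone()
loss = trainable(torch.tensor([0, 1])).square().mean()
loss.backward()
optimizer.step()

print(torch.equal(before, trainable.weight.detach()))  # False: rows were updated
print(vocab.weight.grad)                               # None: never touched
```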
```diff
@@ -741,6 +745,7 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
+        text_embeddings=text_embeddings,
         instance_identifier=args.instance_identifier,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
@@ -774,7 +779,6 @@ def main():
         local_progress_bar.reset()
 
         text_encoder.train()
-        train_loss = 0.0
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
@@ -834,8 +838,6 @@ def main():
                     lr_scheduler.step()
                     optimizer.zero_grad(set_to_none=True)
 
-                    text_embeddings.save()
-
                     avg_loss.update(loss.detach_(), bsz)
                     avg_acc.update(acc.detach_(), bsz)
 
```
```diff
diff --git a/training/ti.py b/training/ti.py
index a5fd8e4..2efd2f2 100644
--- a/training/ti.py
+++ b/training/ti.py
@@ -19,8 +19,8 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
     def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
         super().__init__(config)
 
-        self.token_embedding.requires_grad_(False)
-        self.position_embedding.requires_grad_(False)
+        self.token_embedding.weight.requires_grad = False
+        self.position_embedding.weight.requires_grad = False
 
         self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))}
 
@@ -28,6 +28,7 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
         self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))]
 
         self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices])
+        self.trainable_embedding.weight.requires_grad = True
 
     def forward(
         self,
@@ -64,7 +65,3 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
         embeddings = inputs_embeds + position_embeddings
 
         return embeddings
-
-    @torch.no_grad()
-    def save(self):
-        self.token_embedding.weight.data[self.train_indices] = self.trainable_embedding.weight.data
```
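The one-line addition after `nn.Embedding.from_pretrained` is the likely substance of the fix: `from_pretrained` defaults to `freeze=True`, so the freshly built "trainable" table had `requires_grad=False` and never received gradients, and the now-removed per-step `save()` would only have copied those unchanged rows back over the vocabulary matrix. With the flag flipped the table actually trains, and the Checkpointer now reads from it directly instead of relying on `save()`. A sketch of the default behavior:

```python
import torch
import torch.nn as nn

weights = torch.randn(2, 768)  # stand-in for the selected placeholder rows

frozen = nn.Embedding.from_pretrained(weights)   # freeze=True is the default
print(frozen.weight.requires_grad)               # False: no gradients flow

# Either re-enable gradients explicitly, as the commit does...
frozen.weight.requires_grad = True

# ...or pass freeze=False up front, which is equivalent.
trainable = nn.Embedding.from_pretrained(weights, freeze=False)
print(trainable.weight.requires_grad)            # True
```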