author     Volpeon <git@volpeon.ink>  2022-12-22 21:15:24 +0100
committer  Volpeon <git@volpeon.ink>  2022-12-22 21:15:24 +0100
commit     ee9a2777c15d4ceea7ef40802b9a21881f6428a8 (patch)
tree       20c8b89d58fdd1ec5fc9b3f1cb7a515d6ad78a79 /train_ti.py
parent     Improved Textual Inversion: Completely exclude untrained embeddings from trai... (diff)
download   textual-inversion-diff-ee9a2777c15d4ceea7ef40802b9a21881f6428a8.tar.gz
           textual-inversion-diff-ee9a2777c15d4ceea7ef40802b9a21881f6428a8.tar.bz2
           textual-inversion-diff-ee9a2777c15d4ceea7ef40802b9a21881f6428a8.zip
Fixed Textual Inversion
Diffstat (limited to 'train_ti.py')
 train_ti.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)
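The diff below routes every embedding read and write through a text_embeddings object whose class is not defined in this file. From the three members the patch touches (trainable_embedding, id_mapping, and save()), a minimal sketch of what such a wrapper could look like; the class name and constructor are assumptions for illustration, not the repository's actual code:

    import torch
    import torch.nn as nn

    class TrainableEmbeddings(nn.Module):
        # Hypothetical reconstruction: keep only the placeholder tokens'
        # vectors in a small trainable embedding, leaving the rest of the
        # text encoder's input embeddings untouched. id_mapping translates
        # a tokenizer token id into a row index of the small embedding.
        def __init__(self, text_encoder, placeholder_token_ids):
            super().__init__()
            full = text_encoder.get_input_embeddings()
            self.full_embedding = full
            self.id_mapping = {tid: i for i, tid in enumerate(placeholder_token_ids)}
            self.trainable_embedding = nn.Embedding(
                len(placeholder_token_ids), full.embedding_dim)
            with torch.no_grad():
                for tid, i in self.id_mapping.items():
                    self.trainable_embedding.weight[i] = full.weight[tid]

        @torch.no_grad()
        def save(self):
            # Write the trained rows back into the full embedding matrix.
            for tid, i in self.id_mapping.items():
                self.full_embedding.weight[tid] = self.trainable_embedding.weight[i]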
diff --git a/train_ti.py b/train_ti.py
index bb51dc2..e933c48 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -365,6 +365,7 @@ class Checkpointer(CheckpointerBase):
         tokenizer,
         text_encoder,
         scheduler,
+        text_embeddings,
         instance_identifier,
         placeholder_token,
         placeholder_token_id,
@@ -392,6 +393,7 @@ class Checkpointer(CheckpointerBase):
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.scheduler = scheduler
+        self.text_embeddings = text_embeddings
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
@@ -403,8 +405,10 @@ class Checkpointer(CheckpointerBase):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
+            training_token_id = self.text_embeddings.id_mapping[placeholder_token_id]
+
             # Save a checkpoint
-            learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds = self.text_embeddings.trainable_embedding.weight[training_token_id]
             learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
             filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
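The key change in this hunk: the checkpointed vector is now read from the small trainable embedding (through the id_mapping indirection) rather than from the text encoder's full input-embedding matrix, so the saved file contains the row the optimizer actually updated. For completeness, a sketch of loading such a .bin back into a model; the file name is illustrative, and tokenizer/text_encoder are assumed to be a transformers CLIP tokenizer and text model:

    import torch

    # Illustrative file name following the format string above.
    learned_embeds_dict = torch.load("cat-toy_1000_end.bin")
    for token, embeds in learned_embeds_dict.items():
        token_id = tokenizer.convert_tokens_to_ids(token)
        with torch.no_grad():
            text_encoder.get_input_embeddings().weight[token_id] = embeds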
@@ -543,7 +547,7 @@ def main():
 
     # Initialize the optimizer
     optimizer = optimizer_class(
-        text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
+        text_embeddings.trainable_embedding.parameters(), # only optimize the embeddings
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
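This hunk is the substantive half of the fix: passing text_encoder.get_input_embeddings().parameters() to Adam registers the whole vocabulary matrix as a single parameter, so weight decay and momentum nudge every token's embedding, not only the placeholders. Handing over text_embeddings.trainable_embedding.parameters() restricts optimization to the placeholder rows, matching the parent commit's goal of excluding untrained embeddings. A sketch of the difference in parameter counts, assuming the wrapper sketched earlier and Stable Diffusion's CLIP text encoder:

    # With CLIP ViT-L's text encoder, the full matrix is the entire vocabulary,
    # while the wrapper holds only the placeholder rows.
    full = sum(p.numel() for p in text_encoder.get_input_embeddings().parameters())
    small = sum(p.numel() for p in text_embeddings.trainable_embedding.parameters())
    print(full, small)  # e.g. 49408 * 768 = 37,945,344 vs. n_placeholders * 768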
@@ -741,6 +745,7 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
+        text_embeddings=text_embeddings,
         instance_identifier=args.instance_identifier,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
@@ -774,7 +779,6 @@ def main():
         local_progress_bar.reset()
 
         text_encoder.train()
-        train_loss = 0.0
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
@@ -834,8 +838,6 @@ def main():
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
-                text_embeddings.save()
-
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
 
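The removed text_embeddings.save() wrote the trainable rows back into the full embedding matrix after every optimizer step. This hunk drops it from the inner loop; presumably the write-back now happens at checkpoint or epoch boundaries instead, which is not visible in this patch. A sketch of that pattern, using the assumed wrapper:

    for epoch in range(num_train_epochs):
        text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            ...  # forward, backward, optimizer.step(), as in the loop above
        # Sync the trained rows back once per epoch rather than every step.
        text_embeddings.save()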