summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-22 21:15:24 +0100
committerVolpeon <git@volpeon.ink>2022-12-22 21:15:24 +0100
commitee9a2777c15d4ceea7ef40802b9a21881f6428a8 (patch)
tree20c8b89d58fdd1ec5fc9b3f1cb7a515d6ad78a79 /train_ti.py
parentImproved Textual Inversion: Completely exclude untrained embeddings from trai... (diff)
downloadtextual-inversion-diff-ee9a2777c15d4ceea7ef40802b9a21881f6428a8.tar.gz
textual-inversion-diff-ee9a2777c15d4ceea7ef40802b9a21881f6428a8.tar.bz2
textual-inversion-diff-ee9a2777c15d4ceea7ef40802b9a21881f6428a8.zip
Fixed Textual Inversion
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/train_ti.py b/train_ti.py
index bb51dc2..e933c48 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -365,6 +365,7 @@ class Checkpointer(CheckpointerBase):
365 tokenizer, 365 tokenizer,
366 text_encoder, 366 text_encoder,
367 scheduler, 367 scheduler,
368 text_embeddings,
368 instance_identifier, 369 instance_identifier,
369 placeholder_token, 370 placeholder_token,
370 placeholder_token_id, 371 placeholder_token_id,
@@ -392,6 +393,7 @@ class Checkpointer(CheckpointerBase):
392 self.tokenizer = tokenizer 393 self.tokenizer = tokenizer
393 self.text_encoder = text_encoder 394 self.text_encoder = text_encoder
394 self.scheduler = scheduler 395 self.scheduler = scheduler
396 self.text_embeddings = text_embeddings
395 397
396 @torch.no_grad() 398 @torch.no_grad()
397 def checkpoint(self, step, postfix): 399 def checkpoint(self, step, postfix):
@@ -403,8 +405,10 @@ class Checkpointer(CheckpointerBase):
403 text_encoder = self.accelerator.unwrap_model(self.text_encoder) 405 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
404 406
405 for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): 407 for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
408 training_token_id = self.text_embeddings.id_mapping[placeholder_token_id]
409
406 # Save a checkpoint 410 # Save a checkpoint
407 learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id] 411 learned_embeds = self.text_embeddings.trainable_embedding.weight[training_token_id]
408 learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} 412 learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
409 413
410 filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) 414 filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
@@ -543,7 +547,7 @@ def main():
543 547
544 # Initialize the optimizer 548 # Initialize the optimizer
545 optimizer = optimizer_class( 549 optimizer = optimizer_class(
546 text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings 550 text_embeddings.trainable_embedding.parameters(), # only optimize the embeddings
547 lr=args.learning_rate, 551 lr=args.learning_rate,
548 betas=(args.adam_beta1, args.adam_beta2), 552 betas=(args.adam_beta1, args.adam_beta2),
549 weight_decay=args.adam_weight_decay, 553 weight_decay=args.adam_weight_decay,
@@ -741,6 +745,7 @@ def main():
741 tokenizer=tokenizer, 745 tokenizer=tokenizer,
742 text_encoder=text_encoder, 746 text_encoder=text_encoder,
743 scheduler=checkpoint_scheduler, 747 scheduler=checkpoint_scheduler,
748 text_embeddings=text_embeddings,
744 instance_identifier=args.instance_identifier, 749 instance_identifier=args.instance_identifier,
745 placeholder_token=args.placeholder_token, 750 placeholder_token=args.placeholder_token,
746 placeholder_token_id=placeholder_token_id, 751 placeholder_token_id=placeholder_token_id,
@@ -774,7 +779,6 @@ def main():
774 local_progress_bar.reset() 779 local_progress_bar.reset()
775 780
776 text_encoder.train() 781 text_encoder.train()
777 train_loss = 0.0
778 782
779 for step, batch in enumerate(train_dataloader): 783 for step, batch in enumerate(train_dataloader):
780 with accelerator.accumulate(text_encoder): 784 with accelerator.accumulate(text_encoder):
@@ -834,8 +838,6 @@ def main():
834 lr_scheduler.step() 838 lr_scheduler.step()
835 optimizer.zero_grad(set_to_none=True) 839 optimizer.zero_grad(set_to_none=True)
836 840
837 text_embeddings.save()
838
839 avg_loss.update(loss.detach_(), bsz) 841 avg_loss.update(loss.detach_(), bsz)
840 avg_acc.update(acc.detach_(), bsz) 842 avg_acc.update(acc.detach_(), bsz)
841 843