path: root/train_ti.py
author    Volpeon <git@volpeon.ink>  2023-01-05 10:26:17 +0100
committer Volpeon <git@volpeon.ink>  2023-01-05 10:26:17 +0100
commit    4ef37d87d5a04bc6bb7dacee0660bba3057cc02f (patch)
tree      ffcae0352c1b0784b18572f66bba3c617165cef1 /train_ti.py
parent    Various cleanups (diff)
Fix
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  19
1 file changed, 11 insertions(+), 8 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 5df6850..164cf67 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -450,7 +450,8 @@ class Checkpointer(CheckpointerBase):
         tokenizer,
         text_encoder,
         scheduler,
-        new_tokens,
+        placeholder_token,
+        new_ids,
         output_dir: Path,
         sample_image_size,
         sample_batches,
@@ -473,7 +474,8 @@ class Checkpointer(CheckpointerBase):
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.scheduler = scheduler
-        self.new_tokens = new_tokens
+        self.placeholder_token = placeholder_token
+        self.new_ids = new_ids
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
@@ -484,10 +486,10 @@ class Checkpointer(CheckpointerBase):
 
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        for new_token in self.new_tokens:
+        for (token, ids) in zip(self.placeholder_token, self.new_ids):
             text_encoder.text_model.embeddings.save_embed(
-                new_token.ids,
-                checkpoints_path.joinpath(f"{slugify(new_token.token)}_{step}_{postfix}.bin")
+                ids,
+                checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
             )
 
         del text_encoder
@@ -572,7 +574,7 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
-        print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}")
+        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     # Convert the initializer_token, placeholder_token to ids
     initializer_token_ids = [
@@ -588,7 +590,7 @@ def main():
         for (new_id, init_ids) in zip(new_ids, initializer_token_ids)
     ]
 
-    print(f"Added {len(new_ids)} new tokens: {zip(args.placeholder_token, new_ids, init_ratios)}")
+    print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}")
 
     vae.requires_grad_(False)
     unet.requires_grad_(False)
@@ -882,7 +884,8 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
-        new_tokens=new_tokens,
+        placeholder_token=args.placeholder_token,
+        new_ids=new_ids,
         output_dir=basepath,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
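
Note on the two print() changes above: zip() returns a lazy iterator in Python 3, so interpolating it directly into an f-string prints something like "<zip object at 0x...>" rather than the token/id pairs; wrapping it in list() materializes the pairs before formatting. A minimal, illustrative sketch (the token strings and ids below are made up, not taken from this repo):

    # zip() alone formats as an opaque iterator object
    tokens = ["<token-a>", "<token-b>"]
    ids = [49408, 49409]
    print(f"Added tokens: {zip(tokens, ids)}")        # Added tokens: <zip object at 0x...>
    # list(zip(...)) materializes the pairs so they appear in the output
    print(f"Added tokens: {list(zip(tokens, ids))}")  # Added tokens: [('<token-a>', 49408), ('<token-b>', 49409)]

The Checkpointer change follows the same pattern: instead of iterating objects that carry both a token and its ids, it now receives the parallel lists placeholder_token and new_ids and zips them when saving each embedding.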