path: root/train_ti.py
author     Volpeon <git@volpeon.ink>  2023-04-09 11:29:31 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-09 11:29:31 +0200
commit     ba9fd1a10746d85d2502c8a79ac49db63d346b04 (patch)
tree       568bf65a0a4dcea2c34de4006b5761d0d6564307 /train_ti.py
parent     Fix (diff)
download   textual-inversion-diff-ba9fd1a10746d85d2502c8a79ac49db63d346b04.tar.gz
           textual-inversion-diff-ba9fd1a10746d85d2502c8a79ac49db63d346b04.tar.bz2
           textual-inversion-diff-ba9fd1a10746d85d2502c8a79ac49db63d346b04.zip
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  66
1 file changed, 42 insertions(+), 24 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index daf8bc5..2d51800 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -458,6 +458,12 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--emb_dropout",
+        type=float,
+        default=0,
+        help="Embedding dropout probability.",
+    )
+    parser.add_argument(
         "--use_emb_decay",
         action="store_true",
         help="Whether to use embedding decay."
@@ -624,7 +630,7 @@ def main():
     save_args(output_dir, args)

     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path)
+        args.pretrained_model_name_or_path, args.emb_dropout)

     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -755,8 +761,6 @@ def main():
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")

-    checkpoint_output_dir = output_dir / "checkpoints"
-
     trainer = partial(
         train,
         accelerator=accelerator,
@@ -777,7 +781,6 @@ def main():
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
         # --
-        checkpoint_output_dir=checkpoint_output_dir,
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
         emb_decay=args.emb_decay,
@@ -793,11 +796,6 @@ def main():
     )

     def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
-        if len(placeholder_tokens) == 1:
-            sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}"
-        else:
-            sample_output_dir = output_dir / "samples"
-
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
@@ -809,7 +807,11 @@ def main():

         stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids))

-        print(f"{i + 1}: {stats}")
+        print("")
+        print(f"============ TI batch {i + 1} ============")
+        print("")
+        print(stats)
+        print("")

         filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]

@@ -868,20 +870,36 @@ def main():
             mid_point=args.lr_mid_point,
         )

-        trainer(
-            project="textual_inversion",
-            train_dataloader=datamodule.train_dataloader,
-            val_dataloader=datamodule.val_dataloader,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            num_train_epochs=num_train_epochs,
-            # --
-            group_labels=["emb"],
-            sample_output_dir=sample_output_dir,
-            sample_frequency=sample_frequency,
-            placeholder_tokens=placeholder_tokens,
-            placeholder_token_ids=placeholder_token_ids,
-        )
+        continue_training = True
+        training_iter = 1
+
+        while continue_training:
+            print(f"------------ TI cycle {training_iter} ------------")
+            print("")
+
+            project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}"
+            sample_output_dir = output_dir / project / "samples"
+            checkpoint_output_dir = output_dir / project / "checkpoints"
+
+            trainer(
+                project=project,
+                train_dataloader=datamodule.train_dataloader,
+                val_dataloader=datamodule.val_dataloader,
+                optimizer=optimizer,
+                lr_scheduler=lr_scheduler,
+                num_train_epochs=num_train_epochs,
+                # --
+                group_labels=["emb"],
+                checkpoint_output_dir=checkpoint_output_dir,
+                sample_output_dir=sample_output_dir,
+                sample_frequency=sample_frequency,
+                placeholder_tokens=placeholder_tokens,
+                placeholder_token_ids=placeholder_token_ids,
+            )
+
+            response = input("Run another cycle? [y/n] ")
+            continue_training = response.lower().strip() != "n"
+            training_iter += 1

     if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
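For context, the hunk above replaces the single trainer() call with a confirm-and-repeat cycle: each pass gets its own per-cycle project name and sample/checkpoint directories, and an input() prompt decides whether to start another cycle. The following is a minimal, self-contained sketch of that loop pattern only; run_cycle is a hypothetical stand-in for the real trainer(...) call and is not part of the commit.

# Sketch of the per-cycle loop introduced in this commit (assumptions: run_cycle
# is a placeholder for trainer(...); output_dir and placeholder_tokens are given).
from pathlib import Path


def run_cycle(project: str, sample_dir: Path, checkpoint_dir: Path) -> None:
    # Placeholder for one training pass; train_ti.py calls trainer(...) here.
    print(f"training {project}: samples -> {sample_dir}, checkpoints -> {checkpoint_dir}")


def main(output_dir: Path, placeholder_tokens: list[str]) -> None:
    continue_training = True
    training_iter = 1

    while continue_training:
        # One sub-directory per cycle, named after the token when there is only one.
        project = (
            f"{placeholder_tokens[0]}_{training_iter}"
            if len(placeholder_tokens) == 1
            else f"{training_iter}"
        )
        run_cycle(
            project,
            output_dir / project / "samples",
            output_dir / project / "checkpoints",
        )

        # Ask whether to run another cycle; anything other than "n" continues.
        response = input("Run another cycle? [y/n] ")
        continue_training = response.lower().strip() != "n"
        training_iter += 1


if __name__ == "__main__":
    main(Path("output"), ["<token>"])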