author     Volpeon <git@volpeon.ink>  2023-04-10 10:34:12 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-10 10:34:12 +0200
commit     eb6a92abda5893c975437026cdaf0ce0bfefe2a4 (patch)
tree       a1525010b48362986e0cc2b7c3f7505a35dea71a /train_ti.py
parent     Update (diff)
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  68
1 file changed, 41 insertions(+), 27 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index ebac302..eb08bda 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -152,6 +152,11 @@ def parse_args():
         help="The embeddings directory where Textual Inversion embeddings are stored.",
     )
     parser.add_argument(
+        "--train_dir_embeddings",
+        action="store_true",
+        help="Train embeddings loaded from embeddings directory.",
+    )
+    parser.add_argument(
         "--collection",
         type=str,
         nargs='*',
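
Note: the new switch is a plain boolean `store_true` flag. A minimal, self-contained sketch (not taken from the repo) of how it parses alongside the existing `--embeddings_dir` option; the directory path is illustrative:

```python
# Minimal sketch, assuming only the argparse semantics shown in the hunk above;
# the path value is illustrative, not a repo default.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--embeddings_dir", type=str, default=None)
parser.add_argument("--train_dir_embeddings", action="store_true")

# Equivalent CLI: python train_ti.py --embeddings_dir embeddings --train_dir_embeddings
args = parser.parse_args(["--embeddings_dir", "embeddings", "--train_dir_embeddings"])
assert args.train_dir_embeddings and args.embeddings_dir == "embeddings"
```
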
@@ -404,6 +409,12 @@ def parse_args():
         help="If checkpoints are saved on maximum accuracy",
     )
     parser.add_argument(
+        "--sample_num",
+        type=int,
+        default=None,
+        help="How often to save a checkpoint and sample image (in number of samples)",
+    )
+    parser.add_argument(
         "--sample_frequency",
         type=int,
         default=1,
@@ -669,9 +680,14 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")

         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
-        embeddings.persist()
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

+        if args.train_dir_embeddings:
+            args.placeholder_tokens = added_tokens
+            print("Training embeddings from embeddings dir")
+        else:
+            embeddings.persist()
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
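
Note: with `--train_dir_embeddings` set, the tokens loaded from the embeddings directory become the placeholder tokens that training updates; otherwise they are persisted as before. A hedged sketch of just this branch, with stand-in arguments (the real `args`, `added_tokens`, and `embeddings` objects come from the script itself):

```python
# Hedged sketch of the new branch; the parameters are stand-ins for objects
# the script already has at this point in main().
def select_training_tokens(args, added_tokens, embeddings):
    if args.train_dir_embeddings:
        # Keep the loaded embeddings trainable and train exactly those tokens.
        args.placeholder_tokens = added_tokens
        print("Training embeddings from embeddings dir")
    else:
        # Previous behaviour: persist the loaded embeddings and leave
        # args.placeholder_tokens untouched.
        embeddings.persist()
    return args.placeholder_tokens
```
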
@@ -852,28 +868,8 @@ def main():
             args.num_train_steps / len(datamodule.train_dataset)
         ) * args.gradient_accumulation_steps
         sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
-
-    optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.token_override_embedding.parameters(),
-        lr=args.learning_rate,
-    )
-
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_training_steps_per_epoch=len(datamodule.train_dataloader),
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        min_lr=args.lr_min_lr,
-        warmup_func=args.lr_warmup_func,
-        annealing_func=args.lr_annealing_func,
-        warmup_exp=args.lr_warmup_exp,
-        annealing_exp=args.lr_annealing_exp,
-        cycles=args.lr_cycles,
-        end_lr=1e3,
-        train_epochs=num_train_epochs,
-        warmup_epochs=args.lr_warmup_epochs,
-        mid_point=args.lr_mid_point,
-    )
+    if args.sample_num is not None:
+        sample_frequency = math.ceil(num_train_epochs / args.sample_num)

     training_iter = 0

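
Note: `--sample_num` is read as a total number of sample/checkpoint events for the run, so the per-epoch interval becomes `ceil(num_train_epochs / sample_num)` and overrides the value derived from `--sample_frequency`. The optimizer/scheduler construction removed here is not dropped; it moves into the per-cycle loop in the next hunk. A small worked example of the interval math (numbers are illustrative):

```python
import math

num_train_epochs = 30   # illustrative
sample_num = 10         # illustrative value for --sample_num

# ceil(30 / 10) = 3: sample and checkpoint every 3 epochs,
# giving roughly 10 sampling events over the whole run.
sample_frequency = math.ceil(num_train_epochs / sample_num)
assert sample_frequency == 3
```
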
@@ -888,6 +884,28 @@ def main():
         print(f"------------ TI cycle {training_iter} ------------")
         print("")

+        optimizer = create_optimizer(
+            text_encoder.text_model.embeddings.token_override_embedding.parameters(),
+            lr=args.learning_rate,
+        )
+
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_training_steps_per_epoch=len(datamodule.train_dataloader),
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            min_lr=args.lr_min_lr,
+            warmup_func=args.lr_warmup_func,
+            annealing_func=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
+            cycles=args.lr_cycles,
+            end_lr=1e3,
+            train_epochs=num_train_epochs,
+            warmup_epochs=args.lr_warmup_epochs,
+            mid_point=args.lr_mid_point,
+        )
+
         project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}"
         sample_output_dir = output_dir / project / "samples"
         checkpoint_output_dir = output_dir / project / "checkpoints"
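
Note: the optimizer and LR scheduler are now built inside the TI-cycle loop, so every cycle starts from fresh optimizer state and a fresh schedule instead of reusing the objects constructed once before the loop. A hedged, self-contained sketch of that structure; PyTorch's AdamW and OneCycleLR stand in for the repo's `create_optimizer`/`get_scheduler` helpers, and the loop bound is illustrative:

```python
import torch

# Stand-in for the trainable embedding parameters.
params = [torch.nn.Parameter(torch.zeros(4, 768))]

training_iter = 0
num_cycles = 3  # illustrative; the script controls continuation differently
while training_iter < num_cycles:
    # Rebuilt every cycle: no optimizer state or schedule carries over.
    optimizer = torch.optim.AdamW(params, lr=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=1e-4, total_steps=100
    )
    # ... one full TI training cycle would run here ...
    training_iter += 1
```
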
@@ -908,10 +926,6 @@ def main():
             placeholder_token_ids=placeholder_token_ids,
         )

-        response = input("Run another cycle? [y/n] ")
-        continue_training = response.lower().strip() != "n"
-        training_iter += 1
-
     if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else: