author      Volpeon <git@volpeon.ink>  2023-04-27 08:39:33 +0200
committer   Volpeon <git@volpeon.ink>  2023-04-27 08:39:33 +0200
commit      fa4b797914bc9997039562afefd64620bf235d60 (patch)
tree        5fa851380bb9149b273f73c313f27103505e967a
parent      Update (diff)
download    textual-inversion-diff-fa4b797914bc9997039562afefd64620bf235d60.tar.gz
            textual-inversion-diff-fa4b797914bc9997039562afefd64620bf235d60.tar.bz2
            textual-inversion-diff-fa4b797914bc9997039562afefd64620bf235d60.zip
Fix
-rw-r--r--  train_lora.py  27
1 file changed, 19 insertions(+), 8 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 9cf17c7..d5aa78d 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -512,6 +512,11 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--run_pti",
+        action="store_true",
+        help="Whether to run PTI."
+    )
+    parser.add_argument(
         "--emb_alpha",
         type=float,
         default=1.0,
@@ -714,6 +719,7 @@ def main():
         embeddings.persist()
         print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}")
 
+    placeholder_tokens = []
     placeholder_token_ids = []
 
     if args.embeddings_dir is not None:
@@ -722,16 +728,18 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+
+        placeholder_tokens = added_tokens
+        placeholder_token_ids = added_ids
+
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
         if args.train_dir_embeddings:
-            args.placeholder_tokens = added_tokens
-            placeholder_token_ids = added_ids
             print("Training embeddings from embeddings dir")
         else:
             embeddings.persist()
 
-    if not args.train_dir_embeddings:
+    if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings:
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
@@ -740,8 +748,11 @@ def main():
             num_vectors=args.num_vectors,
             initializer_noise=args.initializer_noise,
         )
+
+        placeholder_tokens = args.placeholder_tokens
+
         stats = list(zip(
-            args.placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids
+            placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids
         ))
         print(f"Training embeddings: {stats}")
 
@@ -878,7 +889,7 @@ def main():
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
         sample_image_size=args.sample_image_size,
-        placeholder_tokens=args.placeholder_tokens,
+        placeholder_tokens=placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
@@ -924,8 +935,8 @@ def main():
     # PTI
     # --------------------------------------------------------------------------------
 
-    if len(args.placeholder_tokens) != 0:
-        filter_tokens = [token for token in args.filter_tokens if token in args.placeholder_tokens]
+    if args.run_pti and len(placeholder_tokens) != 0:
+        filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]
 
         pti_datamodule = create_datamodule(
             valid_set_size=0,
@@ -1094,7 +1105,7 @@ def main():
             group_labels.append("unet")
 
         if training_iter < args.train_text_encoder_cycles:
-            # if len(args.placeholder_tokens) != 0:
+            # if len(placeholder_tokens) != 0:
             #     params_to_optimize.append({
             #         "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
             #         "lr": learning_rate_emb,
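Taken together, the change gates the PTI stage behind the new --run_pti flag and keeps the trained placeholder tokens in a local placeholder_tokens variable rather than writing them back into args. Below is a minimal, self-contained sketch of that control flow; the two helpers are simplified stand-ins (not the repo's real load_embeddings_from_dir / add_placeholder_tokens), and only the names visible in the diff are taken from the source.

# Sketch of the control flow introduced by this commit; helper functions are
# hypothetical stand-ins, not the actual implementations in train_lora.py.
import argparse


def load_embeddings_from_dir_stub(embeddings_dir):
    # Stand-in: pretend the directory contained two learned embeddings.
    return ["<dir-token-1>", "<dir-token-2>"], [49408, 49409]


def add_placeholder_tokens_stub(placeholder_tokens):
    # Stand-in: assign dummy ids to CLI-provided placeholder tokens.
    return list(range(49500, 49500 + len(placeholder_tokens)))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.")
    parser.add_argument("--placeholder_tokens", nargs="*", default=[])
    parser.add_argument("--train_dir_embeddings", action="store_true")
    parser.add_argument("--embeddings_dir", default=None)
    args = parser.parse_args()

    # placeholder_tokens is now a local working set; args is never mutated.
    placeholder_tokens = []
    placeholder_token_ids = []

    if args.embeddings_dir is not None:
        added_tokens, added_ids = load_embeddings_from_dir_stub(args.embeddings_dir)
        placeholder_tokens = added_tokens
        placeholder_token_ids = added_ids

    if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings:
        placeholder_token_ids = add_placeholder_tokens_stub(args.placeholder_tokens)
        placeholder_tokens = args.placeholder_tokens

    # PTI is now opt-in and needs at least one placeholder token.
    if args.run_pti and len(placeholder_tokens) != 0:
        print(f"Running PTI on {list(zip(placeholder_tokens, placeholder_token_ids))}")
    else:
        print("Skipping PTI")


if __name__ == "__main__":
    main()

With this change, invoking train_lora.py without --run_pti skips the PTI block even when placeholder tokens are defined; previously PTI ran whenever len(args.placeholder_tokens) != 0.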