 train_lora.py | 27 +++++++++++++++++++--------
 1 file changed, 19 insertions(+), 8 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 9cf17c7..d5aa78d 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -512,6 +512,11 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--run_pti",
+        action="store_true",
+        help="Whether to run PTI."
+    )
+    parser.add_argument(
         "--emb_alpha",
         type=float,
         default=1.0,
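
The new --run_pti switch is a plain argparse store_true flag: it defaults to False when omitted, so the PTI phase only runs when the flag is passed explicitly. A minimal standalone sketch of that behaviour (argparse only, nothing else from train_lora.py):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run_pti",
    action="store_true",
    help="Whether to run PTI."
)

print(parser.parse_args([]).run_pti)             # False: flag omitted, PTI skipped
print(parser.parse_args(["--run_pti"]).run_pti)  # True: flag given, PTI phase enabled
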
@@ -714,6 +719,7 @@ def main():
         embeddings.persist()
         print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}")
 
+    placeholder_tokens = []
     placeholder_token_ids = []
 
     if args.embeddings_dir is not None:
@@ -722,16 +728,18 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+
+        placeholder_tokens = added_tokens
+        placeholder_token_ids = added_ids
+
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
         if args.train_dir_embeddings:
-            args.placeholder_tokens = added_tokens
-            placeholder_token_ids = added_ids
             print("Training embeddings from embeddings dir")
         else:
             embeddings.persist()
 
-    if not args.train_dir_embeddings:
+    if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings:
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
@@ -740,8 +748,11 @@ def main():
             num_vectors=args.num_vectors,
             initializer_noise=args.initializer_noise,
         )
+
+        placeholder_tokens = args.placeholder_tokens
+
         stats = list(zip(
-            args.placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids
+            placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids
         ))
         print(f"Training embeddings: {stats}")
 
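
Taken together, these hunks stop mutating args.placeholder_tokens and instead resolve a local placeholder_tokens list: tokens loaded from --embeddings_dir are used as-is, and command-line tokens take over whenever placeholder tokens are configured and --train_dir_embeddings is not set. A self-contained sketch of that resolution order (resolve_tokens and the sample token names are illustrative, not part of the script):

from types import SimpleNamespace

def resolve_tokens(args, dir_tokens):
    # Mirrors the diff's flow: start empty, prefer tokens found in the
    # embeddings dir, then let CLI tokens replace them unless training
    # the dir embeddings was requested.
    placeholder_tokens = dir_tokens if dir_tokens else []
    if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings:
        placeholder_tokens = args.placeholder_tokens
    return placeholder_tokens

args = SimpleNamespace(placeholder_tokens=["<new>"], train_dir_embeddings=False)
print(resolve_tokens(args, dir_tokens=["<from_dir>"]))  # ['<new>']
args = SimpleNamespace(placeholder_tokens=[], train_dir_embeddings=False)
print(resolve_tokens(args, dir_tokens=["<from_dir>"]))  # ['<from_dir>']
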
@@ -878,7 +889,7 @@ def main():
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
         sample_image_size=args.sample_image_size,
-        placeholder_tokens=args.placeholder_tokens,
+        placeholder_tokens=placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
@@ -924,8 +935,8 @@ def main():
     # PTI
     # --------------------------------------------------------------------------------
 
-    if len(args.placeholder_tokens) != 0:
-        filter_tokens = [token for token in args.filter_tokens if token in args.placeholder_tokens]
+    if args.run_pti and len(placeholder_tokens) != 0:
+        filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]
 
         pti_datamodule = create_datamodule(
             valid_set_size=0,
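
The PTI block is now opt-in: it requires --run_pti in addition to a non-empty placeholder_tokens list, and the token filter is computed against the local list rather than args.placeholder_tokens. With illustrative values, the gate and filter behave like this:

run_pti = True
placeholder_tokens = ["<concept>", "<style>"]
filter_tokens_arg = ["<style>", "<unrelated>"]

if run_pti and len(placeholder_tokens) != 0:
    filter_tokens = [token for token in filter_tokens_arg if token in placeholder_tokens]
    print(filter_tokens)  # ['<style>']: only tokens that are actually being trained
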
@@ -1094,7 +1105,7 @@ def main():
             group_labels.append("unet")
 
         if training_iter < args.train_text_encoder_cycles:
-            # if len(args.placeholder_tokens) != 0:
+            # if len(placeholder_tokens) != 0:
             #     params_to_optimize.append({
             #         "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
             #         "lr": learning_rate_emb,