 train_lora.py | 27 +++++++++++++++++++--------
 1 file changed, 19 insertions(+), 8 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index 9cf17c7..d5aa78d 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -512,6 +512,11 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--run_pti",
+        action="store_true",
+        help="Whether to run PTI."
+    )
+    parser.add_argument(
         "--emb_alpha",
         type=float,
         default=1.0,
@@ -714,6 +719,7 @@ def main():
         embeddings.persist()
         print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}")
 
+    placeholder_tokens = []
     placeholder_token_ids = []
 
     if args.embeddings_dir is not None:
@@ -722,16 +728,18 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+
+        placeholder_tokens = added_tokens
+        placeholder_token_ids = added_ids
+
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
         if args.train_dir_embeddings:
-            args.placeholder_tokens = added_tokens
-            placeholder_token_ids = added_ids
             print("Training embeddings from embeddings dir")
         else:
             embeddings.persist()
 
-    if not args.train_dir_embeddings:
+    if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings:
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
@@ -740,8 +748,11 @@ def main():
             num_vectors=args.num_vectors,
             initializer_noise=args.initializer_noise,
         )
+
+        placeholder_tokens = args.placeholder_tokens
+
         stats = list(zip(
-            args.placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids
+            placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids
         ))
         print(f"Training embeddings: {stats}")
 
@@ -878,7 +889,7 @@ def main():
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
         sample_image_size=args.sample_image_size,
-        placeholder_tokens=args.placeholder_tokens,
+        placeholder_tokens=placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
@@ -924,8 +935,8 @@ def main():
     # PTI
     # --------------------------------------------------------------------------------
 
-    if len(args.placeholder_tokens) != 0:
-        filter_tokens = [token for token in args.filter_tokens if token in args.placeholder_tokens]
+    if args.run_pti and len(placeholder_tokens) != 0:
+        filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]
 
         pti_datamodule = create_datamodule(
             valid_set_size=0,
@@ -1094,7 +1105,7 @@ def main():
             group_labels.append("unet")
 
         if training_iter < args.train_text_encoder_cycles:
-            # if len(placeholder_tokens) != 0:
+            # if len(placeholder_tokens) != 0:
             #     params_to_optimize.append({
             #         "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
             #         "lr": learning_rate_emb,
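As context for reading the patch, the sketch below shows the pattern it introduces in isolation: an opt-in `store_true` flag (`--run_pti`) and a local `placeholder_tokens` list that is populated independently of `args` and then gates the PTI phase. This is a minimal, illustrative reconstruction, not code from the commit; the `--placeholder_tokens` definition here is a simplified stand-in for the real script's option.

import argparse

def parse_args():
    parser = argparse.ArgumentParser()
    # Opt-in flag added by the patch: PTI only runs when --run_pti is passed.
    parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.")
    # Simplified stand-in for the real option of the same name (assumed shape).
    parser.add_argument("--placeholder_tokens", type=str, nargs="*", default=[])
    return parser.parse_args()

def main():
    args = parse_args()

    # Local variable decoupled from args: in the real script it may be filled
    # either from --placeholder_tokens or from an embeddings directory.
    placeholder_tokens = list(args.placeholder_tokens)

    # The PTI phase is now doubly gated: the flag must be set and tokens must exist.
    if args.run_pti and len(placeholder_tokens) != 0:
        print(f"Running PTI for {placeholder_tokens}")
    else:
        print("Skipping PTI")

if __name__ == "__main__":
    main()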
