-rw-r--r--	train_lora.py | 111
1 file changed, 71 insertions(+), 40 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index d8a4880..f1e7ec7 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -159,6 +159,12 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--pti_tag_dropout",
+        type=float,
+        default=0,
+        help="Tag dropout probability.",
+    )
+    parser.add_argument(
         "--no_tag_shuffle",
         action="store_true",
         help="Shuffle tags.",
@@ -891,7 +897,6 @@ def main():
             progressive_buckets=args.progressive_buckets,
             bucket_step_size=args.bucket_step_size,
             bucket_max_pixels=args.bucket_max_pixels,
-            dropout=args.tag_dropout,
             shuffle=not args.no_tag_shuffle,
             template_key=args.train_data_template,
             valid_set_size=args.valid_set_size,
@@ -919,12 +924,9 @@ def main():
     # --------------------------------------------------------------------------------
 
     if len(args.placeholder_tokens) != 0:
-        pti_output_dir = output_dir / "pti"
-        pti_checkpoint_output_dir = pti_output_dir / "model"
-        pti_sample_output_dir = pti_output_dir / "samples"
-
         pti_datamodule = create_datamodule(
             batch_size=args.pti_batch_size,
+            dropout=args.pti_tag_dropout,
             filter=partial(keyword_filter, args.filter_tokens, args.collection, args.exclude_collections),
         )
         pti_datamodule.setup()
@@ -955,22 +957,38 @@ def main():
             train_epochs=num_pti_epochs,
         )
 
-        trainer(
-            strategy=lora_strategy,
-            pti_mode=True,
-            project="pti",
-            train_dataloader=pti_datamodule.train_dataloader,
-            val_dataloader=pti_datamodule.val_dataloader,
-            optimizer=pti_optimizer,
-            lr_scheduler=pti_lr_scheduler,
-            num_train_epochs=num_pti_epochs,
-            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-            # --
-            group_labels=["emb"],
-            sample_output_dir=pti_sample_output_dir,
-            checkpoint_output_dir=pti_checkpoint_output_dir,
-            sample_frequency=pti_sample_frequency,
-        )
+        continue_training = True
+        training_iter = 1
+
+        while continue_training:
+            print("")
+            print(f"============ PTI cycle {training_iter} ============")
+            print("")
+
+            pti_output_dir = output_dir / f"pti_{training_iter}"
+            pti_checkpoint_output_dir = pti_output_dir / "model"
+            pti_sample_output_dir = pti_output_dir / "samples"
+
+            trainer(
+                strategy=lora_strategy,
+                pti_mode=True,
+                project="pti",
+                train_dataloader=pti_datamodule.train_dataloader,
+                val_dataloader=pti_datamodule.val_dataloader,
+                optimizer=pti_optimizer,
+                lr_scheduler=pti_lr_scheduler,
+                num_train_epochs=num_pti_epochs,
+                gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+                # --
+                group_labels=["emb"],
+                sample_output_dir=pti_sample_output_dir,
+                checkpoint_output_dir=pti_checkpoint_output_dir,
+                sample_frequency=pti_sample_frequency,
+            )
+
+            response = input("Run another cycle? [y/n] ")
+            continue_training = response.lower().strip() != "n"
+            training_iter += 1
 
         if not args.train_emb:
             embeddings.persist()
@@ -978,12 +996,9 @@ def main():
     # LORA
     # --------------------------------------------------------------------------------
 
-    lora_output_dir = output_dir / "lora"
-    lora_checkpoint_output_dir = lora_output_dir / "model"
-    lora_sample_output_dir = lora_output_dir / "samples"
-
     lora_datamodule = create_datamodule(
         batch_size=args.train_batch_size,
+        dropout=args.tag_dropout,
         filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
     )
     lora_datamodule.setup()
@@ -1037,21 +1052,37 @@ def main():
         train_epochs=num_train_epochs,
     )
 
-    trainer(
-        strategy=lora_strategy,
-        project="lora",
-        train_dataloader=lora_datamodule.train_dataloader,
-        val_dataloader=lora_datamodule.val_dataloader,
-        optimizer=lora_optimizer,
-        lr_scheduler=lora_lr_scheduler,
-        num_train_epochs=num_train_epochs,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        # --
-        group_labels=group_labels,
-        sample_output_dir=lora_sample_output_dir,
-        checkpoint_output_dir=lora_checkpoint_output_dir,
-        sample_frequency=lora_sample_frequency,
-    )
+    continue_training = True
+    training_iter = 1
+
+    while continue_training:
+        print("")
+        print(f"============ LoRA cycle {training_iter} ============")
+        print("")
+
+        lora_output_dir = output_dir / f"lora_{training_iter}"
+        lora_checkpoint_output_dir = lora_output_dir / "model"
+        lora_sample_output_dir = lora_output_dir / "samples"
+
+        trainer(
+            strategy=lora_strategy,
+            project=f"lora_{training_iter}",
+            train_dataloader=lora_datamodule.train_dataloader,
+            val_dataloader=lora_datamodule.val_dataloader,
+            optimizer=lora_optimizer,
+            lr_scheduler=lora_lr_scheduler,
+            num_train_epochs=num_train_epochs,
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            # --
+            group_labels=group_labels,
+            sample_output_dir=lora_sample_output_dir,
+            checkpoint_output_dir=lora_checkpoint_output_dir,
+            sample_frequency=lora_sample_frequency,
+        )
+
+        response = input("Run another cycle? [y/n] ")
+        continue_training = response.lower().strip() != "n"
+        training_iter += 1
 
 
 if __name__ == "__main__":
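Both new loops follow the same pattern, so a minimal standalone sketch of it is given below. The helper name run_training_cycles and the run_cycle callback are illustrative only and do not exist in train_lora.py; they simply stand in for the trainer(...) call shown in the hunks above.

# Standalone sketch of the interactive training-cycle pattern added in this commit.
# `run_cycle` stands in for the real trainer(...) call; all names are illustrative.
from pathlib import Path
from typing import Callable


def run_training_cycles(output_dir: Path, prefix: str, run_cycle: Callable[..., None]) -> None:
    continue_training = True
    training_iter = 1

    while continue_training:
        print(f"\n============ {prefix} cycle {training_iter} ============\n")

        # Each cycle writes to its own directory, e.g. output/pti_1, output/pti_2, ...
        cycle_dir = output_dir / f"{prefix}_{training_iter}"
        run_cycle(
            checkpoint_output_dir=cycle_dir / "model",
            sample_output_dir=cycle_dir / "samples",
        )

        # Anything other than an explicit "n" starts another cycle.
        response = input("Run another cycle? [y/n] ")
        continue_training = response.lower().strip() != "n"
        training_iter += 1

As in the patch, the optimizer, LR scheduler and epoch count are created once before the loop, so each additional cycle simply continues training with the same settings while writing checkpoints and samples to a fresh pti_N or lora_N directory.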