-rw-r--r--	train_lora.py	111
1 files changed, 71 insertions, 40 deletions
diff --git a/train_lora.py b/train_lora.py
index d8a4880..f1e7ec7 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -159,6 +159,12 @@ def parse_args():
159 help="Tag dropout probability.", 159 help="Tag dropout probability.",
160 ) 160 )
161 parser.add_argument( 161 parser.add_argument(
162 "--pti_tag_dropout",
163 type=float,
164 default=0,
165 help="Tag dropout probability.",
166 )
167 parser.add_argument(
162 "--no_tag_shuffle", 168 "--no_tag_shuffle",
163 action="store_true", 169 action="store_true",
164 help="Shuffle tags.", 170 help="Shuffle tags.",
@@ -891,7 +897,6 @@ def main():
         progressive_buckets=args.progressive_buckets,
         bucket_step_size=args.bucket_step_size,
         bucket_max_pixels=args.bucket_max_pixels,
-        dropout=args.tag_dropout,
         shuffle=not args.no_tag_shuffle,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
@@ -919,12 +924,9 @@ def main():
     # --------------------------------------------------------------------------------

     if len(args.placeholder_tokens) != 0:
-        pti_output_dir = output_dir / "pti"
-        pti_checkpoint_output_dir = pti_output_dir / "model"
-        pti_sample_output_dir = pti_output_dir / "samples"
-
         pti_datamodule = create_datamodule(
             batch_size=args.pti_batch_size,
+            dropout=args.pti_tag_dropout,
             filter=partial(keyword_filter, args.filter_tokens, args.collection, args.exclude_collections),
         )
         pti_datamodule.setup()
@@ -955,22 +957,38 @@ def main():
             train_epochs=num_pti_epochs,
         )

-        trainer(
-            strategy=lora_strategy,
-            pti_mode=True,
-            project="pti",
-            train_dataloader=pti_datamodule.train_dataloader,
-            val_dataloader=pti_datamodule.val_dataloader,
-            optimizer=pti_optimizer,
-            lr_scheduler=pti_lr_scheduler,
-            num_train_epochs=num_pti_epochs,
-            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-            # --
-            group_labels=["emb"],
-            sample_output_dir=pti_sample_output_dir,
-            checkpoint_output_dir=pti_checkpoint_output_dir,
-            sample_frequency=pti_sample_frequency,
-        )
+        continue_training = True
+        training_iter = 1
+
+        while continue_training:
+            print("")
+            print(f"============ PTI cycle {training_iter} ============")
+            print("")
+
+            pti_output_dir = output_dir / f"pti_{training_iter}"
+            pti_checkpoint_output_dir = pti_output_dir / "model"
+            pti_sample_output_dir = pti_output_dir / "samples"
+
+            trainer(
+                strategy=lora_strategy,
+                pti_mode=True,
+                project="pti",
+                train_dataloader=pti_datamodule.train_dataloader,
+                val_dataloader=pti_datamodule.val_dataloader,
+                optimizer=pti_optimizer,
+                lr_scheduler=pti_lr_scheduler,
+                num_train_epochs=num_pti_epochs,
+                gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+                # --
+                group_labels=["emb"],
+                sample_output_dir=pti_sample_output_dir,
+                checkpoint_output_dir=pti_checkpoint_output_dir,
+                sample_frequency=pti_sample_frequency,
+            )
+
+            response = input("Run another cycle? [y/n] ")
+            continue_training = response.lower().strip() != "n"
+            training_iter += 1

         if not args.train_emb:
             embeddings.persist()
@@ -978,12 +996,9 @@ def main():
     # LORA
     # --------------------------------------------------------------------------------

-    lora_output_dir = output_dir / "lora"
-    lora_checkpoint_output_dir = lora_output_dir / "model"
-    lora_sample_output_dir = lora_output_dir / "samples"
-
     lora_datamodule = create_datamodule(
         batch_size=args.train_batch_size,
+        dropout=args.tag_dropout,
         filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
     )
     lora_datamodule.setup()
@@ -1037,21 +1052,37 @@ def main():
         train_epochs=num_train_epochs,
     )

-    trainer(
-        strategy=lora_strategy,
-        project="lora",
-        train_dataloader=lora_datamodule.train_dataloader,
-        val_dataloader=lora_datamodule.val_dataloader,
-        optimizer=lora_optimizer,
-        lr_scheduler=lora_lr_scheduler,
-        num_train_epochs=num_train_epochs,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        # --
-        group_labels=group_labels,
-        sample_output_dir=lora_sample_output_dir,
-        checkpoint_output_dir=lora_checkpoint_output_dir,
-        sample_frequency=lora_sample_frequency,
-    )
+    continue_training = True
+    training_iter = 1
+
+    while continue_training:
+        print("")
+        print(f"============ LoRA cycle {training_iter} ============")
+        print("")
+
+        lora_output_dir = output_dir / f"lora_{training_iter}"
+        lora_checkpoint_output_dir = lora_output_dir / "model"
+        lora_sample_output_dir = lora_output_dir / "samples"
+
+        trainer(
+            strategy=lora_strategy,
+            project=f"lora_{training_iter}",
+            train_dataloader=lora_datamodule.train_dataloader,
+            val_dataloader=lora_datamodule.val_dataloader,
+            optimizer=lora_optimizer,
+            lr_scheduler=lora_lr_scheduler,
+            num_train_epochs=num_train_epochs,
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            # --
+            group_labels=group_labels,
+            sample_output_dir=lora_sample_output_dir,
+            checkpoint_output_dir=lora_checkpoint_output_dir,
+            sample_frequency=lora_sample_frequency,
+        )
+
+        response = input("Run another cycle? [y/n] ")
+        continue_training = response.lower().strip() != "n"
+        training_iter += 1


 if __name__ == "__main__":
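
For reference, the following is a minimal, self-contained sketch of the pattern this commit introduces: a separate tag-dropout value per phase handed to the datamodule, and an interactive loop that repeats training cycles, each writing into its own numbered output directory, until the user answers "n". It is an illustration only; the helpers run_training_cycle and run_phase are placeholders that do not exist in train_lora.py, which wires the same loop around its real trainer() and create_datamodule() calls as shown in the diff above.

import argparse
from pathlib import Path


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--tag_dropout", type=float, default=0,
                        help="Tag dropout probability for the LoRA phase.")
    parser.add_argument("--pti_tag_dropout", type=float, default=0,
                        help="Tag dropout probability for the PTI phase.")
    parser.add_argument("--output_dir", type=Path, default=Path("output"),
                        help="Base directory for per-cycle outputs.")
    return parser.parse_args()


def run_training_cycle(phase: str, dropout: float, cycle_dir: Path) -> None:
    # Stand-in for the real trainer() call; it only reproduces the
    # per-cycle directory layout (model/ and samples/) used in the diff.
    (cycle_dir / "model").mkdir(parents=True, exist_ok=True)
    (cycle_dir / "samples").mkdir(parents=True, exist_ok=True)
    print(f"[{phase}] dropout={dropout} -> {cycle_dir}")


def run_phase(phase: str, dropout: float, base_dir: Path) -> None:
    continue_training = True
    training_iter = 1

    while continue_training:
        print("")
        print(f"============ {phase} cycle {training_iter} ============")
        print("")

        # Each cycle writes to its own directory: pti_1, pti_2, ..., lora_1, ...
        run_training_cycle(phase, dropout, base_dir / f"{phase.lower()}_{training_iter}")

        response = input("Run another cycle? [y/n] ")
        continue_training = response.lower().strip() != "n"
        training_iter += 1


if __name__ == "__main__":
    args = parse_args()
    run_phase("pti", args.pti_tag_dropout, args.output_dir)
    run_phase("lora", args.tag_dropout, args.output_dir)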