Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py | 53
1 file changed, 33 insertions(+), 20 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 4d4c16a..ba5aee1 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -84,9 +84,9 @@ def parse_args():
     )
     parser.add_argument(
         "--auto_cycles",
-        type=int,
-        default=1,
-        help="How many cycles to run automatically."
+        type=str,
+        default="o",
+        help="Cycles to run automatically."
     )
     parser.add_argument(
         "--cycle_decay",
@@ -95,11 +95,6 @@ def parse_args():
         help="Learning rate decay per cycle."
     )
     parser.add_argument(
-        "--cycle_constant",
-        action="store_true",
-        help="Use constant LR on cycles > 1."
-    )
-    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -920,7 +915,6 @@ def main():
         annealing_func=args.lr_annealing_func,
         warmup_exp=args.lr_warmup_exp,
         annealing_exp=args.lr_annealing_exp,
-        cycles=args.lr_cycles,
         end_lr=1e2,
         mid_point=args.lr_mid_point,
     )
@@ -964,20 +958,38 @@ def main():
 
     lora_sample_output_dir = output_dir / lora_project / "samples"
 
+    auto_cycles = list(args.auto_cycles)
+    lr_scheduler = args.lr_scheduler
+    lr_warmup_epochs = args.lr_warmup_epochs
+    lr_cycles = args.lr_cycles
+
     while True:
-        if training_iter >= args.auto_cycles:
-            response = input("Run another cycle? [y/n] ")
-            if response.lower().strip() == "n":
-                break
+        if len(auto_cycles) != 0:
+            response = auto_cycles.pop(0)
+        else:
+            response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")
+
+        if response.lower().strip() == "o":
+            lr_scheduler = "one_cycle"
+            lr_warmup_epochs = args.lr_warmup_epochs
+            lr_cycles = args.lr_cycles
+        if response.lower().strip() == "w":
+            lr_scheduler = "constant"
+            lr_warmup_epochs = num_train_epochs
+        if response.lower().strip() == "c":
+            lr_scheduler = "constant"
+            lr_warmup_epochs = 0
+        if response.lower().strip() == "d":
+            lr_scheduler = "cosine"
+            lr_warmup_epochs = 0
+            lr_cycles = 1
+        elif response.lower().strip() == "s":
+            break
 
         print("")
         print(f"============ LoRA cycle {training_iter + 1} ============")
         print("")
 
-        if args.cycle_constant and training_iter == 1:
-            args.lr_scheduler = "constant"
-            args.lr_warmup_epochs = 0
-
         params_to_optimize = []
 
         if len(args.placeholder_tokens) != 0:
@@ -1012,12 +1024,13 @@ def main():
         lora_optimizer = create_optimizer(params_to_optimize)
 
         lora_lr_scheduler = create_lr_scheduler(
-            args.lr_scheduler,
+            lr_scheduler,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
             optimizer=lora_optimizer,
             num_training_steps_per_epoch=len(lora_datamodule.train_dataloader),
             train_epochs=num_train_epochs,
-            warmup_epochs=args.lr_warmup_epochs,
+            cycles=lr_cycles,
+            warmup_epochs=lr_warmup_epochs,
         )
 
         lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}"
@@ -1031,7 +1044,7 @@ def main():
             num_train_epochs=num_train_epochs,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
             global_step_offset=training_iter * num_train_steps,
-            initial_samples=training_iter == 0,
+            cycle=training_iter,
             # --
             group_labels=group_labels,
             sample_output_dir=lora_sample_output_dir,
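
For reference, a minimal sketch (not part of the patch) of how the reworked --auto_cycles string is consumed: each character selects the scheduler action for one training cycle, and once the list is exhausted the script falls back to the interactive prompt. The helper name next_action and the example argument "ow" are illustrative only.

    # Sketch of the per-cycle action selection introduced by this change.
    # Each character of --auto_cycles maps to an action:
    #   o = one_cycle, w = constant LR with full warmup, c = constant LR,
    #   d = cosine decay (single cycle), s = stop.
    def next_action(auto_cycles):
        if len(auto_cycles) != 0:
            # Pre-selected action from the --auto_cycles string.
            return auto_cycles.pop(0)
        # No pre-selected actions left: ask interactively, as in the new loop body.
        return input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")

    # e.g. --auto_cycles ow runs one_cycle for the first cycle, a constant-LR
    # warmup cycle for the second, and then prompts the user.
    auto_cycles = list("ow")
    first = next_action(auto_cycles)   # "o"
    second = next_action(auto_cycles)  # "w"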