Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  53
1 file changed, 33 insertions(+), 20 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 4d4c16a..ba5aee1 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -84,9 +84,9 @@ def parse_args():
     )
     parser.add_argument(
         "--auto_cycles",
-        type=int,
-        default=1,
-        help="How many cycles to run automatically."
+        type=str,
+        default="o",
+        help="Cycles to run automatically."
     )
     parser.add_argument(
         "--cycle_decay",
@@ -95,11 +95,6 @@ def parse_args():
         help="Learning rate decay per cycle."
     )
     parser.add_argument(
-        "--cycle_constant",
-        action="store_true",
-        help="Use constant LR on cycles > 1."
-    )
-    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
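Note on the two parse_args() changes above: --auto_cycles switches from an integer cycle count to a string of single-character action keys, and --cycle_constant is removed because its effect is now reachable as the "c" (constant) action in the loop further below. main() consumes the string one character per training cycle, then falls back to interactive prompts. A minimal illustration of the splitting, assuming only what this diff shows:

    # main() below does: auto_cycles = list(args.auto_cycles)
    list("o")    # new default: one scripted one_cycle pass, then prompt
    list("owd")  # one_cycle, then warmup, then decay, then prompt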
@@ -920,7 +915,6 @@ def main():
         annealing_func=args.lr_annealing_func,
         warmup_exp=args.lr_warmup_exp,
         annealing_exp=args.lr_annealing_exp,
-        cycles=args.lr_cycles,
         end_lr=1e2,
         mid_point=args.lr_mid_point,
     )
@@ -964,20 +958,38 @@ def main():
 
     lora_sample_output_dir = output_dir / lora_project / "samples"
 
+    auto_cycles = list(args.auto_cycles)
+    lr_scheduler = args.lr_scheduler
+    lr_warmup_epochs = args.lr_warmup_epochs
+    lr_cycles = args.lr_cycles
+
     while True:
-        if training_iter >= args.auto_cycles:
-            response = input("Run another cycle? [y/n] ")
-            if response.lower().strip() == "n":
-                break
+        if len(auto_cycles) != 0:
+            response = auto_cycles.pop(0)
+        else:
+            response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")
+
+        if response.lower().strip() == "o":
+            lr_scheduler = "one_cycle"
+            lr_warmup_epochs = args.lr_warmup_epochs
+            lr_cycles = args.lr_cycles
+        if response.lower().strip() == "w":
+            lr_scheduler = "constant"
+            lr_warmup_epochs = num_train_epochs
+        if response.lower().strip() == "c":
+            lr_scheduler = "constant"
+            lr_warmup_epochs = 0
+        if response.lower().strip() == "d":
+            lr_scheduler = "cosine"
+            lr_warmup_epochs = 0
+            lr_cycles = 1
+        elif response.lower().strip() == "s":
+            break
 
         print("")
         print(f"============ LoRA cycle {training_iter + 1} ============")
         print("")
 
-        if args.cycle_constant and training_iter == 1:
-            args.lr_scheduler = "constant"
-            args.lr_warmup_epochs = 0
-
         params_to_optimize = []
 
         if len(args.placeholder_tokens) != 0:
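The action dispatch above chains plain if statements, with only the final "s" branch attached as an elif of the "d" check; an unrecognized key therefore matches nothing, and the next cycle simply reuses the previous settings. A standalone sketch of the same mapping (the function name and returned triple are illustrative, not the author's code; "s" stops the loop and is handled separately in the diff):

    def apply_action(key, args, num_train_epochs, current):
        # current is the (lr_scheduler, lr_warmup_epochs, lr_cycles) triple
        # carried over from the previous cycle.
        key = key.lower().strip()
        if key == "o":  # full one_cycle schedule, settings from the CLI args
            return ("one_cycle", args.lr_warmup_epochs, args.lr_cycles)
        if key == "w":  # constant LR that warms up over the whole cycle
            return ("constant", num_train_epochs, current[2])
        if key == "c":  # constant LR, no warmup
            return ("constant", 0, current[2])
        if key == "d":  # a single cosine decay cycle
            return ("cosine", 0, 1)
        return current  # unknown key: keep the previous cycle's settings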
@@ -1012,12 +1024,13 @@ def main():
         lora_optimizer = create_optimizer(params_to_optimize)
 
         lora_lr_scheduler = create_lr_scheduler(
-            args.lr_scheduler,
+            lr_scheduler,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
             optimizer=lora_optimizer,
             num_training_steps_per_epoch=len(lora_datamodule.train_dataloader),
             train_epochs=num_train_epochs,
-            warmup_epochs=args.lr_warmup_epochs,
+            cycles=lr_cycles,
+            warmup_epochs=lr_warmup_epochs,
         )
 
         lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}"
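Because the scheduler is now built from the per-cycle locals rather than from args, an action chosen at the prompt takes effect in the very next create_lr_scheduler call. A hypothetical trace for --auto_cycles "wd", with the other keyword arguments elided:

    # cycle 1 ("w"): create_lr_scheduler("constant", ..., cycles=args.lr_cycles, warmup_epochs=num_train_epochs)
    # cycle 2 ("d"): create_lr_scheduler("cosine", ..., cycles=1, warmup_epochs=0)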
@@ -1031,7 +1044,7 @@ def main():
             num_train_epochs=num_train_epochs,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
             global_step_offset=training_iter * num_train_steps,
-            initial_samples=training_iter == 0,
+            cycle=training_iter,
             # --
             group_labels=group_labels,
             sample_output_dir=lora_sample_output_dir,