Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py | 23 ++++++++++++++++-------
1 file changed, 16 insertions(+), 7 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 0d8ee23..29e40b2 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -919,6 +919,8 @@ def main():
         args.num_train_steps / len(lora_datamodule.train_dataset)
     ) * args.gradient_accumulation_steps
     lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps))
+    num_training_steps_per_epoch = math.ceil(len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps)
+    num_train_steps = num_training_steps_per_epoch * num_train_epochs
     if args.sample_num is not None:
         lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num)
 
@@ -956,15 +958,19 @@ def main():
 
     training_iter = 0
 
+    lora_project = "lora"
+
+    if accelerator.is_main_process:
+        accelerator.init_trackers(lora_project)
+
     while True:
-        training_iter += 1
-        if training_iter > args.auto_cycles:
+        if training_iter >= args.auto_cycles:
             response = input("Run another cycle? [y/n] ")
             if response.lower().strip() == "n":
                 break
 
         print("")
-        print(f"============ LoRA cycle {training_iter} ============")
+        print(f"============ LoRA cycle {training_iter + 1} ============")
         print("")
 
         lora_optimizer = create_optimizer(params_to_optimize)
@@ -976,19 +982,18 @@ def main():
             train_epochs=num_train_epochs,
         )
 
-        lora_project = f"lora_{training_iter}"
-        lora_checkpoint_output_dir = output_dir / lora_project / "model"
-        lora_sample_output_dir = output_dir / lora_project / "samples"
+        lora_checkpoint_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "model"
+        lora_sample_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "samples"
 
         trainer(
             strategy=lora_strategy,
-            project=lora_project,
             train_dataloader=lora_datamodule.train_dataloader,
             val_dataloader=lora_datamodule.val_dataloader,
             optimizer=lora_optimizer,
             lr_scheduler=lora_lr_scheduler,
             num_train_epochs=num_train_epochs,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
+            global_step_offset=training_iter * num_train_steps,
             # --
             group_labels=group_labels,
             sample_output_dir=lora_sample_output_dir,
@@ -996,6 +1001,10 @@ def main():
             sample_frequency=lora_sample_frequency,
         )
 
+        training_iter += 1
+
+    accelerator.end_training()
+
 
 if __name__ == "__main__":
     main()
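
A minimal sketch of the cycle/step bookkeeping this commit sets up, assuming a tracker-backed Accelerator: trackers are initialized once for a single "lora" project instead of one project per cycle, the cycle counter becomes zero-based, and global_step_offset = training_iter * num_train_steps shifts each cycle's steps, presumably so logged steps stay monotonic across cycles. dataset_len, auto_cycles, and the per-step accelerator.log() call are illustrative stand-ins (the real per-step logging lives inside trainer()); init_trackers(), log(), end_training(), and is_main_process are real Hugging Face Accelerate APIs.

# Sketch only: stand-in values for len(lora_datamodule.train_dataset),
# args.gradient_accumulation_steps, and args.auto_cycles.
import math

from accelerate import Accelerator

accelerator = Accelerator()

dataset_len = 1200
gradient_accumulation_steps = 4
num_train_epochs = 3
auto_cycles = 2

# As in the diff: optimizer steps per epoch, then a fixed step budget per cycle.
num_training_steps_per_epoch = math.ceil(dataset_len / gradient_accumulation_steps)
num_train_steps = num_training_steps_per_epoch * num_train_epochs

if accelerator.is_main_process:
    accelerator.init_trackers("lora")  # one tracker run shared by all cycles

training_iter = 0
while True:
    # Zero-based counter: the first auto_cycles cycles run unattended;
    # after that the user is asked before each extra cycle.
    if training_iter >= auto_cycles:
        if input("Run another cycle? [y/n] ").lower().strip() == "n":
            break

    # Offsetting by whole cycles keeps the logged global step monotonic,
    # mirroring global_step_offset=training_iter * num_train_steps in the diff.
    global_step_offset = training_iter * num_train_steps
    for local_step in range(num_train_steps):
        accelerator.log(
            {"cycle": training_iter + 1},
            step=global_step_offset + local_step,
        )

    training_iter += 1

accelerator.end_training()

Initializing the trackers once outside the loop is what lets the per-cycle project=lora_project argument be dropped from the trainer() call; without the offset, each cycle would restart its step counter at 0 and later cycles would collide with earlier ones in the same tracker run.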