Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  23
1 file changed, 16 insertions(+), 7 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 0d8ee23..29e40b2 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -919,6 +919,8 @@ def main():
         args.num_train_steps / len(lora_datamodule.train_dataset)
     ) * args.gradient_accumulation_steps
     lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps))
+    num_training_steps_per_epoch = math.ceil(len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps)
+    num_train_steps = num_training_steps_per_epoch * num_train_epochs
     if args.sample_num is not None:
         lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num)
 
@@ -956,15 +958,19 @@ def main():
 
     training_iter = 0
 
+    lora_project = "lora"
+
+    if accelerator.is_main_process:
+        accelerator.init_trackers(lora_project)
+
     while True:
-        training_iter += 1
-        if training_iter > args.auto_cycles:
+        if training_iter >= args.auto_cycles:
             response = input("Run another cycle? [y/n] ")
             if response.lower().strip() == "n":
                 break
 
         print("")
-        print(f"============ LoRA cycle {training_iter} ============")
+        print(f"============ LoRA cycle {training_iter + 1} ============")
         print("")
 
         lora_optimizer = create_optimizer(params_to_optimize)
@@ -976,19 +982,18 @@ def main():
             train_epochs=num_train_epochs,
         )
 
-        lora_project = f"lora_{training_iter}"
-        lora_checkpoint_output_dir = output_dir / lora_project / "model"
-        lora_sample_output_dir = output_dir / lora_project / "samples"
+        lora_checkpoint_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "model"
+        lora_sample_output_dir = output_dir / lora_project / f"{training_iter + 1}" / "samples"
 
         trainer(
             strategy=lora_strategy,
-            project=lora_project,
             train_dataloader=lora_datamodule.train_dataloader,
             val_dataloader=lora_datamodule.val_dataloader,
             optimizer=lora_optimizer,
             lr_scheduler=lora_lr_scheduler,
             num_train_epochs=num_train_epochs,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
+            global_step_offset=training_iter * num_train_steps,
             # --
             group_labels=group_labels,
             sample_output_dir=lora_sample_output_dir,
@@ -996,6 +1001,10 @@ def main():
             sample_frequency=lora_sample_frequency,
         )
 
+        training_iter += 1
+
+    accelerator.end_training()
+
 
 if __name__ == "__main__":
     main()
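
Notes: below is a minimal, self-contained sketch of the cycle/step bookkeeping this commit introduces, not the script itself. The constants (dataset size, epoch count, auto_cycles) are made-up example values; in train_lora.py the real numbers come from args and lora_datamodule. It shows why the step offset is computed as training_iter * num_train_steps: each cycle's logged steps continue after the previous cycle's, so the single "lora" tracker sees a monotonically increasing global step instead of restarting at 0 every cycle.

    import math

    # Illustrative values only -- the script reads these from args / the datamodule.
    dataset_len = 400                     # stands in for len(lora_datamodule.train_dataset)
    gradient_accumulation_steps = 4
    num_train_epochs = 5
    auto_cycles = 2

    # One optimizer step per accumulated batch group, rounded up, times the epoch count.
    num_training_steps_per_epoch = math.ceil(dataset_len / gradient_accumulation_steps)
    num_train_steps = num_training_steps_per_epoch * num_train_epochs

    training_iter = 0                     # counts *completed* cycles, as in the new loop
    while training_iter < auto_cycles:    # the real loop prompts before running extra cycles
        # Offsetting by whole cycles keeps the tracker's global step monotonic.
        global_step_offset = training_iter * num_train_steps
        print(f"cycle {training_iter + 1}: logs steps "
              f"{global_step_offset}..{global_step_offset + num_train_steps - 1}")
        training_iter += 1

With the tracker initialized once before the loop and accelerator.end_training() called after it, every cycle logs into the same run; per-cycle separation moves to the filesystem instead, under output_dir / "lora" / "<cycle>" for checkpoints and samples.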