summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--train_lora.py10
-rw-r--r--training/functional.py2
2 files changed, 7 insertions, 5 deletions
diff --git a/train_lora.py b/train_lora.py
index f1e7ec7..8dbe45b 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -965,14 +965,15 @@ def main():
965 print(f"============ PTI cycle {training_iter} ============") 965 print(f"============ PTI cycle {training_iter} ============")
966 print("") 966 print("")
967 967
968 pti_output_dir = output_dir / f"pti_{training_iter}" 968 pti_project = f"pti_{training_iter}"
969 pti_output_dir = output_dir / pti_project
969 pti_checkpoint_output_dir = pti_output_dir / "model" 970 pti_checkpoint_output_dir = pti_output_dir / "model"
970 pti_sample_output_dir = pti_output_dir / "samples" 971 pti_sample_output_dir = pti_output_dir / "samples"
971 972
972 trainer( 973 trainer(
973 strategy=lora_strategy, 974 strategy=lora_strategy,
974 pti_mode=True, 975 pti_mode=True,
975 project="pti", 976 project=pti_project,
976 train_dataloader=pti_datamodule.train_dataloader, 977 train_dataloader=pti_datamodule.train_dataloader,
977 val_dataloader=pti_datamodule.val_dataloader, 978 val_dataloader=pti_datamodule.val_dataloader,
978 optimizer=pti_optimizer, 979 optimizer=pti_optimizer,
@@ -1060,13 +1061,14 @@ def main():
1060 print(f"============ LoRA cycle {training_iter} ============") 1061 print(f"============ LoRA cycle {training_iter} ============")
1061 print("") 1062 print("")
1062 1063
1063 lora_output_dir = output_dir / f"lora_{training_iter}" 1064 lora_project = f"lora_{training_iter}"
1065 lora_output_dir = output_dir / lora_project
1064 lora_checkpoint_output_dir = lora_output_dir / "model" 1066 lora_checkpoint_output_dir = lora_output_dir / "model"
1065 lora_sample_output_dir = lora_output_dir / "samples" 1067 lora_sample_output_dir = lora_output_dir / "samples"
1066 1068
1067 trainer( 1069 trainer(
1068 strategy=lora_strategy, 1070 strategy=lora_strategy,
1069 project=f"lora_{training_iter}", 1071 project=lora_project,
1070 train_dataloader=lora_datamodule.train_dataloader, 1072 train_dataloader=lora_datamodule.train_dataloader,
1071 val_dataloader=lora_datamodule.val_dataloader, 1073 val_dataloader=lora_datamodule.val_dataloader,
1072 optimizer=lora_optimizer, 1074 optimizer=lora_optimizer,
diff --git a/training/functional.py b/training/functional.py
index 71b2fe9..7d49782 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -606,8 +606,8 @@ def train_loop(
606 # Create the pipeline using using the trained modules and save it. 606 # Create the pipeline using using the trained modules and save it.
607 if accelerator.is_main_process: 607 if accelerator.is_main_process:
608 print("Finished!") 608 print("Finished!")
609 on_checkpoint(global_step + global_step_offset, "end")
610 on_sample(global_step + global_step_offset) 609 on_sample(global_step + global_step_offset)
610 on_checkpoint(global_step + global_step_offset, "end")
611 611
612 except KeyboardInterrupt: 612 except KeyboardInterrupt:
613 if accelerator.is_main_process: 613 if accelerator.is_main_process: