diff options
-rw-r--r-- | train_lora.py | 10 | ||||
-rw-r--r-- | training/functional.py | 2 |
2 files changed, 7 insertions, 5 deletions
diff --git a/train_lora.py b/train_lora.py index f1e7ec7..8dbe45b 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -965,14 +965,15 @@ def main(): | |||
965 | print(f"============ PTI cycle {training_iter} ============") | 965 | print(f"============ PTI cycle {training_iter} ============") |
966 | print("") | 966 | print("") |
967 | 967 | ||
968 | pti_output_dir = output_dir / f"pti_{training_iter}" | 968 | pti_project = f"pti_{training_iter}" |
969 | pti_output_dir = output_dir / pti_project | ||
969 | pti_checkpoint_output_dir = pti_output_dir / "model" | 970 | pti_checkpoint_output_dir = pti_output_dir / "model" |
970 | pti_sample_output_dir = pti_output_dir / "samples" | 971 | pti_sample_output_dir = pti_output_dir / "samples" |
971 | 972 | ||
972 | trainer( | 973 | trainer( |
973 | strategy=lora_strategy, | 974 | strategy=lora_strategy, |
974 | pti_mode=True, | 975 | pti_mode=True, |
975 | project="pti", | 976 | project=pti_project, |
976 | train_dataloader=pti_datamodule.train_dataloader, | 977 | train_dataloader=pti_datamodule.train_dataloader, |
977 | val_dataloader=pti_datamodule.val_dataloader, | 978 | val_dataloader=pti_datamodule.val_dataloader, |
978 | optimizer=pti_optimizer, | 979 | optimizer=pti_optimizer, |
@@ -1060,13 +1061,14 @@ def main(): | |||
1060 | print(f"============ LoRA cycle {training_iter} ============") | 1061 | print(f"============ LoRA cycle {training_iter} ============") |
1061 | print("") | 1062 | print("") |
1062 | 1063 | ||
1063 | lora_output_dir = output_dir / f"lora_{training_iter}" | 1064 | lora_project = f"lora_{training_iter}" |
1065 | lora_output_dir = output_dir / lora_project | ||
1064 | lora_checkpoint_output_dir = lora_output_dir / "model" | 1066 | lora_checkpoint_output_dir = lora_output_dir / "model" |
1065 | lora_sample_output_dir = lora_output_dir / "samples" | 1067 | lora_sample_output_dir = lora_output_dir / "samples" |
1066 | 1068 | ||
1067 | trainer( | 1069 | trainer( |
1068 | strategy=lora_strategy, | 1070 | strategy=lora_strategy, |
1069 | project=f"lora_{training_iter}", | 1071 | project=lora_project, |
1070 | train_dataloader=lora_datamodule.train_dataloader, | 1072 | train_dataloader=lora_datamodule.train_dataloader, |
1071 | val_dataloader=lora_datamodule.val_dataloader, | 1073 | val_dataloader=lora_datamodule.val_dataloader, |
1072 | optimizer=lora_optimizer, | 1074 | optimizer=lora_optimizer, |
diff --git a/training/functional.py b/training/functional.py index 71b2fe9..7d49782 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -606,8 +606,8 @@ def train_loop( | |||
606 | # Create the pipeline using using the trained modules and save it. | 606 | # Create the pipeline using using the trained modules and save it. |
607 | if accelerator.is_main_process: | 607 | if accelerator.is_main_process: |
608 | print("Finished!") | 608 | print("Finished!") |
609 | on_checkpoint(global_step + global_step_offset, "end") | ||
610 | on_sample(global_step + global_step_offset) | 609 | on_sample(global_step + global_step_offset) |
610 | on_checkpoint(global_step + global_step_offset, "end") | ||
611 | 611 | ||
612 | except KeyboardInterrupt: | 612 | except KeyboardInterrupt: |
613 | if accelerator.is_main_process: | 613 | if accelerator.is_main_process: |