From 602c5ff86ce5d9b8aee545dd243ff04d8bddf405 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 9 Apr 2023 10:13:02 +0200 Subject: Fix --- train_lora.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'train_lora.py') diff --git a/train_lora.py b/train_lora.py index f1e7ec7..8dbe45b 100644 --- a/train_lora.py +++ b/train_lora.py @@ -965,14 +965,15 @@ def main(): print(f"============ PTI cycle {training_iter} ============") print("") - pti_output_dir = output_dir / f"pti_{training_iter}" + pti_project = f"pti_{training_iter}" + pti_output_dir = output_dir / pti_project pti_checkpoint_output_dir = pti_output_dir / "model" pti_sample_output_dir = pti_output_dir / "samples" trainer( strategy=lora_strategy, pti_mode=True, - project="pti", + project=pti_project, train_dataloader=pti_datamodule.train_dataloader, val_dataloader=pti_datamodule.val_dataloader, optimizer=pti_optimizer, @@ -1060,13 +1061,14 @@ def main(): print(f"============ LoRA cycle {training_iter} ============") print("") - lora_output_dir = output_dir / f"lora_{training_iter}" + lora_project = f"lora_{training_iter}" + lora_output_dir = output_dir / lora_project lora_checkpoint_output_dir = lora_output_dir / "model" lora_sample_output_dir = lora_output_dir / "samples" trainer( strategy=lora_strategy, - project=f"lora_{training_iter}", + project=lora_project, train_dataloader=lora_datamodule.train_dataloader, val_dataloader=lora_datamodule.val_dataloader, optimizer=lora_optimizer, -- cgit v1.2.3-54-g00ecf