From 94b676d91382267e7429bd68362019868affd9d1 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 13 Feb 2023 17:19:18 +0100 Subject: Update --- train_lora.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'train_lora.py') diff --git a/train_lora.py b/train_lora.py index 5fd05cc..a8c1cf6 100644 --- a/train_lora.py +++ b/train_lora.py @@ -392,7 +392,7 @@ def main(): args = parse_args() now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) + output_dir = Path(args.output_dir) / slugify(args.project) / now output_dir.mkdir(parents=True, exist_ok=True) accelerator = Accelerator( @@ -408,7 +408,7 @@ def main(): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) + logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) if args.seed is None: args.seed = torch.random.seed() >> 32 @@ -489,8 +489,8 @@ def main(): prior_loss_weight=args.prior_loss_weight, ) - checkpoint_output_dir = output_dir.joinpath("model") - sample_output_dir = output_dir.joinpath(f"samples") + checkpoint_output_dir = output_dir / "model" + sample_output_dir = output_dir/"samples" datamodule = VlpnDataModule( data_file=args.train_data_file, @@ -562,7 +562,7 @@ def main(): sample_image_size=args.sample_image_size, ) - plot_metrics(metrics, output_dir.joinpath("lr.png")) + plot_metrics(metrics, output_dir/"lr.png") if __name__ == "__main__": -- cgit v1.2.3-54-g00ecf