summaryrefslogtreecommitdiffstats
path: root/train_lora.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_lora.py')
-rw-r--r--train_lora.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/train_lora.py b/train_lora.py
index 5fd05cc..a8c1cf6 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -392,7 +392,7 @@ def main():
392 args = parse_args() 392 args = parse_args()
393 393
394 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 394 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
395 output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) 395 output_dir = Path(args.output_dir) / slugify(args.project) / now
396 output_dir.mkdir(parents=True, exist_ok=True) 396 output_dir.mkdir(parents=True, exist_ok=True)
397 397
398 accelerator = Accelerator( 398 accelerator = Accelerator(
@@ -408,7 +408,7 @@ def main():
408 elif args.mixed_precision == "bf16": 408 elif args.mixed_precision == "bf16":
409 weight_dtype = torch.bfloat16 409 weight_dtype = torch.bfloat16
410 410
411 logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) 411 logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)
412 412
413 if args.seed is None: 413 if args.seed is None:
414 args.seed = torch.random.seed() >> 32 414 args.seed = torch.random.seed() >> 32
@@ -489,8 +489,8 @@ def main():
489 prior_loss_weight=args.prior_loss_weight, 489 prior_loss_weight=args.prior_loss_weight,
490 ) 490 )
491 491
492 checkpoint_output_dir = output_dir.joinpath("model") 492 checkpoint_output_dir = output_dir / "model"
493 sample_output_dir = output_dir.joinpath(f"samples") 493 sample_output_dir = output_dir/"samples"
494 494
495 datamodule = VlpnDataModule( 495 datamodule = VlpnDataModule(
496 data_file=args.train_data_file, 496 data_file=args.train_data_file,
@@ -562,7 +562,7 @@ def main():
562 sample_image_size=args.sample_image_size, 562 sample_image_size=args.sample_image_size,
563 ) 563 )
564 564
565 plot_metrics(metrics, output_dir.joinpath("lr.png")) 565 plot_metrics(metrics, output_dir/"lr.png")
566 566
567 567
568if __name__ == "__main__": 568if __name__ == "__main__":