diff options
Diffstat (limited to 'train_lora.py')
| -rw-r--r-- | train_lora.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/train_lora.py b/train_lora.py index 5fd05cc..a8c1cf6 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -392,7 +392,7 @@ def main(): | |||
| 392 | args = parse_args() | 392 | args = parse_args() |
| 393 | 393 | ||
| 394 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 394 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 395 | output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) | 395 | output_dir = Path(args.output_dir) / slugify(args.project) / now |
| 396 | output_dir.mkdir(parents=True, exist_ok=True) | 396 | output_dir.mkdir(parents=True, exist_ok=True) |
| 397 | 397 | ||
| 398 | accelerator = Accelerator( | 398 | accelerator = Accelerator( |
| @@ -408,7 +408,7 @@ def main(): | |||
| 408 | elif args.mixed_precision == "bf16": | 408 | elif args.mixed_precision == "bf16": |
| 409 | weight_dtype = torch.bfloat16 | 409 | weight_dtype = torch.bfloat16 |
| 410 | 410 | ||
| 411 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) | 411 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) |
| 412 | 412 | ||
| 413 | if args.seed is None: | 413 | if args.seed is None: |
| 414 | args.seed = torch.random.seed() >> 32 | 414 | args.seed = torch.random.seed() >> 32 |
| @@ -489,8 +489,8 @@ def main(): | |||
| 489 | prior_loss_weight=args.prior_loss_weight, | 489 | prior_loss_weight=args.prior_loss_weight, |
| 490 | ) | 490 | ) |
| 491 | 491 | ||
| 492 | checkpoint_output_dir = output_dir.joinpath("model") | 492 | checkpoint_output_dir = output_dir / "model" |
| 493 | sample_output_dir = output_dir.joinpath(f"samples") | 493 | sample_output_dir = output_dir/"samples" |
| 494 | 494 | ||
| 495 | datamodule = VlpnDataModule( | 495 | datamodule = VlpnDataModule( |
| 496 | data_file=args.train_data_file, | 496 | data_file=args.train_data_file, |
| @@ -562,7 +562,7 @@ def main(): | |||
| 562 | sample_image_size=args.sample_image_size, | 562 | sample_image_size=args.sample_image_size, |
| 563 | ) | 563 | ) |
| 564 | 564 | ||
| 565 | plot_metrics(metrics, output_dir.joinpath("lr.png")) | 565 | plot_metrics(metrics, output_dir/"lr.png") |
| 566 | 566 | ||
| 567 | 567 | ||
| 568 | if __name__ == "__main__": | 568 | if __name__ == "__main__": |
