diff options
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 8ac70e8..4c1ec31 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -432,7 +432,7 @@ def main(): | |||
| 432 | args = parse_args() | 432 | args = parse_args() |
| 433 | 433 | ||
| 434 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 434 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 435 | output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) | 435 | output_dir = Path(args.output_dir) / slugify(args.project) / now |
| 436 | output_dir.mkdir(parents=True, exist_ok=True) | 436 | output_dir.mkdir(parents=True, exist_ok=True) |
| 437 | 437 | ||
| 438 | accelerator = Accelerator( | 438 | accelerator = Accelerator( |
| @@ -448,7 +448,7 @@ def main(): | |||
| 448 | elif args.mixed_precision == "bf16": | 448 | elif args.mixed_precision == "bf16": |
| 449 | weight_dtype = torch.bfloat16 | 449 | weight_dtype = torch.bfloat16 |
| 450 | 450 | ||
| 451 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) | 451 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) |
| 452 | 452 | ||
| 453 | if args.seed is None: | 453 | if args.seed is None: |
| 454 | args.seed = torch.random.seed() >> 32 | 454 | args.seed = torch.random.seed() >> 32 |
| @@ -513,8 +513,8 @@ def main(): | |||
| 513 | prior_loss_weight=args.prior_loss_weight, | 513 | prior_loss_weight=args.prior_loss_weight, |
| 514 | ) | 514 | ) |
| 515 | 515 | ||
| 516 | checkpoint_output_dir = output_dir.joinpath("model") | 516 | checkpoint_output_dir = output_dir / "model" |
| 517 | sample_output_dir = output_dir.joinpath(f"samples") | 517 | sample_output_dir = output_dir / "samples" |
| 518 | 518 | ||
| 519 | datamodule = VlpnDataModule( | 519 | datamodule = VlpnDataModule( |
| 520 | data_file=args.train_data_file, | 520 | data_file=args.train_data_file, |
| @@ -596,7 +596,7 @@ def main(): | |||
| 596 | sample_image_size=args.sample_image_size, | 596 | sample_image_size=args.sample_image_size, |
| 597 | ) | 597 | ) |
| 598 | 598 | ||
| 599 | plot_metrics(metrics, output_dir.joinpath("lr.png")) | 599 | plot_metrics(metrics, output_dir / "lr.png") |
| 600 | 600 | ||
| 601 | 601 | ||
| 602 | if __name__ == "__main__": | 602 | if __name__ == "__main__": |
