summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 8ac70e8..4c1ec31 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -432,7 +432,7 @@ def main():
432 args = parse_args() 432 args = parse_args()
433 433
434 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 434 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
435 output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) 435 output_dir = Path(args.output_dir) / slugify(args.project) / now
436 output_dir.mkdir(parents=True, exist_ok=True) 436 output_dir.mkdir(parents=True, exist_ok=True)
437 437
438 accelerator = Accelerator( 438 accelerator = Accelerator(
@@ -448,7 +448,7 @@ def main():
448 elif args.mixed_precision == "bf16": 448 elif args.mixed_precision == "bf16":
449 weight_dtype = torch.bfloat16 449 weight_dtype = torch.bfloat16
450 450
451 logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) 451 logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)
452 452
453 if args.seed is None: 453 if args.seed is None:
454 args.seed = torch.random.seed() >> 32 454 args.seed = torch.random.seed() >> 32
@@ -513,8 +513,8 @@ def main():
513 prior_loss_weight=args.prior_loss_weight, 513 prior_loss_weight=args.prior_loss_weight,
514 ) 514 )
515 515
516 checkpoint_output_dir = output_dir.joinpath("model") 516 checkpoint_output_dir = output_dir / "model"
517 sample_output_dir = output_dir.joinpath(f"samples") 517 sample_output_dir = output_dir / "samples"
518 518
519 datamodule = VlpnDataModule( 519 datamodule = VlpnDataModule(
520 data_file=args.train_data_file, 520 data_file=args.train_data_file,
@@ -596,7 +596,7 @@ def main():
596 sample_image_size=args.sample_image_size, 596 sample_image_size=args.sample_image_size,
597 ) 597 )
598 598
599 plot_metrics(metrics, output_dir.joinpath("lr.png")) 599 plot_metrics(metrics, output_dir / "lr.png")
600 600
601 601
602if __name__ == "__main__": 602if __name__ == "__main__":