diff options
author | Volpeon <git@volpeon.ink> | 2023-02-13 17:19:18 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-13 17:19:18 +0100 |
commit | 94b676d91382267e7429bd68362019868affd9d1 (patch) | |
tree | 513697739ab25217cbfcff630299d02b1f6e98c8 /train_dreambooth.py | |
parent | Integrate Self-Attention-Guided (SAG) Stable Diffusion in my custom pipeline (diff) | |
download | textual-inversion-diff-94b676d91382267e7429bd68362019868affd9d1.tar.gz textual-inversion-diff-94b676d91382267e7429bd68362019868affd9d1.tar.bz2 textual-inversion-diff-94b676d91382267e7429bd68362019868affd9d1.zip |
Update
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 8ac70e8..4c1ec31 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -432,7 +432,7 @@ def main(): | |||
432 | args = parse_args() | 432 | args = parse_args() |
433 | 433 | ||
434 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 434 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
435 | output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) | 435 | output_dir = Path(args.output_dir) / slugify(args.project) / now |
436 | output_dir.mkdir(parents=True, exist_ok=True) | 436 | output_dir.mkdir(parents=True, exist_ok=True) |
437 | 437 | ||
438 | accelerator = Accelerator( | 438 | accelerator = Accelerator( |
@@ -448,7 +448,7 @@ def main(): | |||
448 | elif args.mixed_precision == "bf16": | 448 | elif args.mixed_precision == "bf16": |
449 | weight_dtype = torch.bfloat16 | 449 | weight_dtype = torch.bfloat16 |
450 | 450 | ||
451 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) | 451 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) |
452 | 452 | ||
453 | if args.seed is None: | 453 | if args.seed is None: |
454 | args.seed = torch.random.seed() >> 32 | 454 | args.seed = torch.random.seed() >> 32 |
@@ -513,8 +513,8 @@ def main(): | |||
513 | prior_loss_weight=args.prior_loss_weight, | 513 | prior_loss_weight=args.prior_loss_weight, |
514 | ) | 514 | ) |
515 | 515 | ||
516 | checkpoint_output_dir = output_dir.joinpath("model") | 516 | checkpoint_output_dir = output_dir / "model" |
517 | sample_output_dir = output_dir.joinpath(f"samples") | 517 | sample_output_dir = output_dir / "samples" |
518 | 518 | ||
519 | datamodule = VlpnDataModule( | 519 | datamodule = VlpnDataModule( |
520 | data_file=args.train_data_file, | 520 | data_file=args.train_data_file, |
@@ -596,7 +596,7 @@ def main(): | |||
596 | sample_image_size=args.sample_image_size, | 596 | sample_image_size=args.sample_image_size, |
597 | ) | 597 | ) |
598 | 598 | ||
599 | plot_metrics(metrics, output_dir.joinpath("lr.png")) | 599 | plot_metrics(metrics, output_dir / "lr.png") |
600 | 600 | ||
601 | 601 | ||
602 | if __name__ == "__main__": | 602 | if __name__ == "__main__": |