From a2ccefb05897a8cb32749bc9d83ff9e2b1b8499e Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 28 Sep 2022 18:44:30 +0200 Subject: Better sample file structure --- dreambooth.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dreambooth.py b/dreambooth.py index 0c58ab5..39c4851 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -402,7 +402,7 @@ class Checkpointer: ) all_samples = [] - filename = f"step_{step}_val_stable.png" + file_path = samples_path.joinpath("stable", f"step_{step}.png") data_enum = enumerate(val_data) @@ -427,7 +427,7 @@ class Checkpointer: del samples image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) - image_grid.save(f"{samples_path}/{filename}") + image_grid.save(file_path) del all_samples del image_grid @@ -435,7 +435,7 @@ class Checkpointer: for data, pool in [(val_data, "val"), (train_data, "train")]: all_samples = [] - filename = f"step_{step}_{pool}.png" + file_path = samples_path.joinpath(pool, f"step_{step}.png") data_enum = enumerate(data) @@ -458,7 +458,7 @@ class Checkpointer: del samples image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) - image_grid.save(f"{samples_path}/{filename}") + image_grid.save(file_path) del all_samples del image_grid @@ -474,7 +474,7 @@ def main(): args = parse_args() now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - basepath = Path(args.output_dir).joinpath(slugify(args.identifier)).joinpath(now) + basepath = Path(args.output_dir).joinpath(slugify(args.identifier), now) basepath.mkdir(parents=True, exist_ok=True) accelerator = Accelerator( -- cgit v1.2.3-70-g09d2