diff options
| -rw-r--r-- | dreambooth.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/dreambooth.py b/dreambooth.py index 0c58ab5..39c4851 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -402,7 +402,7 @@ class Checkpointer: | |||
| 402 | ) | 402 | ) |
| 403 | 403 | ||
| 404 | all_samples = [] | 404 | all_samples = [] |
| 405 | filename = f"step_{step}_val_stable.png" | 405 | file_path = samples_path.joinpath("stable", f"step_{step}.png") |
| 406 | 406 | ||
| 407 | data_enum = enumerate(val_data) | 407 | data_enum = enumerate(val_data) |
| 408 | 408 | ||
| @@ -427,7 +427,7 @@ class Checkpointer: | |||
| 427 | del samples | 427 | del samples |
| 428 | 428 | ||
| 429 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) | 429 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) |
| 430 | image_grid.save(f"{samples_path}/{filename}") | 430 | image_grid.save(file_path) |
| 431 | 431 | ||
| 432 | del all_samples | 432 | del all_samples |
| 433 | del image_grid | 433 | del image_grid |
| @@ -435,7 +435,7 @@ class Checkpointer: | |||
| 435 | 435 | ||
| 436 | for data, pool in [(val_data, "val"), (train_data, "train")]: | 436 | for data, pool in [(val_data, "val"), (train_data, "train")]: |
| 437 | all_samples = [] | 437 | all_samples = [] |
| 438 | filename = f"step_{step}_{pool}.png" | 438 | file_path = samples_path.joinpath(pool, f"step_{step}.png") |
| 439 | 439 | ||
| 440 | data_enum = enumerate(data) | 440 | data_enum = enumerate(data) |
| 441 | 441 | ||
| @@ -458,7 +458,7 @@ class Checkpointer: | |||
| 458 | del samples | 458 | del samples |
| 459 | 459 | ||
| 460 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) | 460 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) |
| 461 | image_grid.save(f"{samples_path}/{filename}") | 461 | image_grid.save(file_path) |
| 462 | 462 | ||
| 463 | del all_samples | 463 | del all_samples |
| 464 | del image_grid | 464 | del image_grid |
| @@ -474,7 +474,7 @@ def main(): | |||
| 474 | args = parse_args() | 474 | args = parse_args() |
| 475 | 475 | ||
| 476 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 476 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 477 | basepath = Path(args.output_dir).joinpath(slugify(args.identifier)).joinpath(now) | 477 | basepath = Path(args.output_dir).joinpath(slugify(args.identifier), now) |
| 478 | basepath.mkdir(parents=True, exist_ok=True) | 478 | basepath.mkdir(parents=True, exist_ok=True) |
| 479 | 479 | ||
| 480 | accelerator = Accelerator( | 480 | accelerator = Accelerator( |
