diff options
-rw-r--r-- | dreambooth.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/dreambooth.py b/dreambooth.py index 0c58ab5..39c4851 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -402,7 +402,7 @@ class Checkpointer: | |||
402 | ) | 402 | ) |
403 | 403 | ||
404 | all_samples = [] | 404 | all_samples = [] |
405 | filename = f"step_{step}_val_stable.png" | 405 | file_path = samples_path.joinpath("stable", f"step_{step}.png") |
406 | 406 | ||
407 | data_enum = enumerate(val_data) | 407 | data_enum = enumerate(val_data) |
408 | 408 | ||
@@ -427,7 +427,7 @@ class Checkpointer: | |||
427 | del samples | 427 | del samples |
428 | 428 | ||
429 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) | 429 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) |
430 | image_grid.save(f"{samples_path}/{filename}") | 430 | image_grid.save(file_path) |
431 | 431 | ||
432 | del all_samples | 432 | del all_samples |
433 | del image_grid | 433 | del image_grid |
@@ -435,7 +435,7 @@ class Checkpointer: | |||
435 | 435 | ||
436 | for data, pool in [(val_data, "val"), (train_data, "train")]: | 436 | for data, pool in [(val_data, "val"), (train_data, "train")]: |
437 | all_samples = [] | 437 | all_samples = [] |
438 | filename = f"step_{step}_{pool}.png" | 438 | file_path = samples_path.joinpath(pool, f"step_{step}.png") |
439 | 439 | ||
440 | data_enum = enumerate(data) | 440 | data_enum = enumerate(data) |
441 | 441 | ||
@@ -458,7 +458,7 @@ class Checkpointer: | |||
458 | del samples | 458 | del samples |
459 | 459 | ||
460 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) | 460 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) |
461 | image_grid.save(f"{samples_path}/{filename}") | 461 | image_grid.save(file_path) |
462 | 462 | ||
463 | del all_samples | 463 | del all_samples |
464 | del image_grid | 464 | del image_grid |
@@ -474,7 +474,7 @@ def main(): | |||
474 | args = parse_args() | 474 | args = parse_args() |
475 | 475 | ||
476 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 476 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
477 | basepath = Path(args.output_dir).joinpath(slugify(args.identifier)).joinpath(now) | 477 | basepath = Path(args.output_dir).joinpath(slugify(args.identifier), now) |
478 | basepath.mkdir(parents=True, exist_ok=True) | 478 | basepath.mkdir(parents=True, exist_ok=True) |
479 | 479 | ||
480 | accelerator = Accelerator( | 480 | accelerator = Accelerator( |