diff options
author | Volpeon <git@volpeon.ink> | 2022-09-28 18:44:30 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-09-28 18:44:30 +0200 |
commit | a2ccefb05897a8cb32749bc9d83ff9e2b1b8499e (patch) | |
tree | 3f790955ff8ccc75f107f00edc4dc976fd188abe | |
parent | Batches of size 1 cause error: Expected query.is_contiguous() to be true, but... (diff) | |
download | textual-inversion-diff-a2ccefb05897a8cb32749bc9d83ff9e2b1b8499e.tar.gz textual-inversion-diff-a2ccefb05897a8cb32749bc9d83ff9e2b1b8499e.tar.bz2 textual-inversion-diff-a2ccefb05897a8cb32749bc9d83ff9e2b1b8499e.zip |
Better sample file structure
-rw-r--r-- | dreambooth.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/dreambooth.py b/dreambooth.py index 0c58ab5..39c4851 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -402,7 +402,7 @@ class Checkpointer: | |||
402 | ) | 402 | ) |
403 | 403 | ||
404 | all_samples = [] | 404 | all_samples = [] |
405 | filename = f"step_{step}_val_stable.png" | 405 | file_path = samples_path.joinpath("stable", f"step_{step}.png") |
406 | 406 | ||
407 | data_enum = enumerate(val_data) | 407 | data_enum = enumerate(val_data) |
408 | 408 | ||
@@ -427,7 +427,7 @@ class Checkpointer: | |||
427 | del samples | 427 | del samples |
428 | 428 | ||
429 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) | 429 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) |
430 | image_grid.save(f"{samples_path}/{filename}") | 430 | image_grid.save(file_path) |
431 | 431 | ||
432 | del all_samples | 432 | del all_samples |
433 | del image_grid | 433 | del image_grid |
@@ -435,7 +435,7 @@ class Checkpointer: | |||
435 | 435 | ||
436 | for data, pool in [(val_data, "val"), (train_data, "train")]: | 436 | for data, pool in [(val_data, "val"), (train_data, "train")]: |
437 | all_samples = [] | 437 | all_samples = [] |
438 | filename = f"step_{step}_{pool}.png" | 438 | file_path = samples_path.joinpath(pool, f"step_{step}.png") |
439 | 439 | ||
440 | data_enum = enumerate(data) | 440 | data_enum = enumerate(data) |
441 | 441 | ||
@@ -458,7 +458,7 @@ class Checkpointer: | |||
458 | del samples | 458 | del samples |
459 | 459 | ||
460 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) | 460 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) |
461 | image_grid.save(f"{samples_path}/{filename}") | 461 | image_grid.save(file_path) |
462 | 462 | ||
463 | del all_samples | 463 | del all_samples |
464 | del image_grid | 464 | del image_grid |
@@ -474,7 +474,7 @@ def main(): | |||
474 | args = parse_args() | 474 | args = parse_args() |
475 | 475 | ||
476 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 476 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
477 | basepath = Path(args.output_dir).joinpath(slugify(args.identifier)).joinpath(now) | 477 | basepath = Path(args.output_dir).joinpath(slugify(args.identifier), now) |
478 | basepath.mkdir(parents=True, exist_ok=True) | 478 | basepath.mkdir(parents=True, exist_ok=True) |
479 | 479 | ||
480 | accelerator = Accelerator( | 480 | accelerator = Accelerator( |