summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-09-28 18:44:30 +0200
committerVolpeon <git@volpeon.ink>2022-09-28 18:44:30 +0200
commita2ccefb05897a8cb32749bc9d83ff9e2b1b8499e (patch)
tree3f790955ff8ccc75f107f00edc4dc976fd188abe
parentBatches of size 1 cause error: Expected query.is_contiguous() to be true, but... (diff)
downloadtextual-inversion-diff-a2ccefb05897a8cb32749bc9d83ff9e2b1b8499e.tar.gz
textual-inversion-diff-a2ccefb05897a8cb32749bc9d83ff9e2b1b8499e.tar.bz2
textual-inversion-diff-a2ccefb05897a8cb32749bc9d83ff9e2b1b8499e.zip
Better sample file structure
-rw-r--r--dreambooth.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 0c58ab5..39c4851 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -402,7 +402,7 @@ class Checkpointer:
402 ) 402 )
403 403
404 all_samples = [] 404 all_samples = []
405 filename = f"step_{step}_val_stable.png" 405 file_path = samples_path.joinpath("stable", f"step_{step}.png")
406 406
407 data_enum = enumerate(val_data) 407 data_enum = enumerate(val_data)
408 408
@@ -427,7 +427,7 @@ class Checkpointer:
427 del samples 427 del samples
428 428
429 image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) 429 image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
430 image_grid.save(f"{samples_path}/{filename}") 430 image_grid.save(file_path)
431 431
432 del all_samples 432 del all_samples
433 del image_grid 433 del image_grid
@@ -435,7 +435,7 @@ class Checkpointer:
435 435
436 for data, pool in [(val_data, "val"), (train_data, "train")]: 436 for data, pool in [(val_data, "val"), (train_data, "train")]:
437 all_samples = [] 437 all_samples = []
438 filename = f"step_{step}_{pool}.png" 438 file_path = samples_path.joinpath(pool, f"step_{step}.png")
439 439
440 data_enum = enumerate(data) 440 data_enum = enumerate(data)
441 441
@@ -458,7 +458,7 @@ class Checkpointer:
458 del samples 458 del samples
459 459
460 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) 460 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
461 image_grid.save(f"{samples_path}/{filename}") 461 image_grid.save(file_path)
462 462
463 del all_samples 463 del all_samples
464 del image_grid 464 del image_grid
@@ -474,7 +474,7 @@ def main():
474 args = parse_args() 474 args = parse_args()
475 475
476 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 476 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
477 basepath = Path(args.output_dir).joinpath(slugify(args.identifier)).joinpath(now) 477 basepath = Path(args.output_dir).joinpath(slugify(args.identifier), now)
478 basepath.mkdir(parents=True, exist_ok=True) 478 basepath.mkdir(parents=True, exist_ok=True)
479 479
480 accelerator = Accelerator( 480 accelerator = Accelerator(