From 7218588c03be1e7fa5566b8836826e6b1c9065d2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Dec 2022 08:32:02 +0100 Subject: Fix --- data/csv.py | 4 ++-- training/util.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/data/csv.py b/data/csv.py index 265293b..b45ac77 100644 --- a/data/csv.py +++ b/data/csv.py @@ -106,11 +106,11 @@ class CSVDataModule(): expansions ), prompt_to_keywords( - cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")), + cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), expansions ), prompt_to_keywords( - prompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), + nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), expansions ), ) diff --git a/training/util.py b/training/util.py index 000173d..a623dc5 100644 --- a/training/util.py +++ b/training/util.py @@ -74,6 +74,9 @@ class CheckpointerBase: generator=generator, ) + grid_cols = max(self.sample_batch_size, 4) + grid_rows = self.sample_batches * self.sample_batch_size / grid_cols + with torch.autocast("cuda"), torch.inference_mode(): for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: all_samples = [] @@ -119,7 +122,7 @@ class CheckpointerBase: del samples - image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) + image_grid = make_grid(all_samples, grid_rows, grid_cols) image_grid.save(file_path, quality=85) del all_samples -- cgit v1.2.3-70-g09d2