diff options
| -rw-r--r-- | data/csv.py | 4 | ||||
| -rw-r--r-- | training/util.py | 5 |
2 files changed, 6 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py index 265293b..b45ac77 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -106,11 +106,11 @@ class CSVDataModule(): | |||
| 106 | expansions | 106 | expansions |
| 107 | ), | 107 | ), |
| 108 | prompt_to_keywords( | 108 | prompt_to_keywords( |
| 109 | cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")), | 109 | cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), |
| 110 | expansions | 110 | expansions |
| 111 | ), | 111 | ), |
| 112 | prompt_to_keywords( | 112 | prompt_to_keywords( |
| 113 | prompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), | 113 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), |
| 114 | expansions | 114 | expansions |
| 115 | ), | 115 | ), |
| 116 | ) | 116 | ) |
diff --git a/training/util.py b/training/util.py index 000173d..a623dc5 100644 --- a/training/util.py +++ b/training/util.py | |||
| @@ -74,6 +74,9 @@ class CheckpointerBase: | |||
| 74 | generator=generator, | 74 | generator=generator, |
| 75 | ) | 75 | ) |
| 76 | 76 | ||
| 77 | grid_cols = max(self.sample_batch_size, 4) | ||
| 78 | grid_rows = self.sample_batches * self.sample_batch_size / grid_cols | ||
| 79 | |||
| 77 | with torch.autocast("cuda"), torch.inference_mode(): | 80 | with torch.autocast("cuda"), torch.inference_mode(): |
| 78 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: | 81 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: |
| 79 | all_samples = [] | 82 | all_samples = [] |
| @@ -119,7 +122,7 @@ class CheckpointerBase: | |||
| 119 | 122 | ||
| 120 | del samples | 123 | del samples |
| 121 | 124 | ||
| 122 | image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) | 125 | image_grid = make_grid(all_samples, grid_rows, grid_cols) |
| 123 | image_grid.save(file_path, quality=85) | 126 | image_grid.save(file_path, quality=85) |
| 124 | 127 | ||
| 125 | del all_samples | 128 | del all_samples |
