diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/training/functional.py b/training/functional.py index 546aaff..34a701b 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -166,9 +166,10 @@ def save_samples( | |||
166 | guidance_scale=guidance_scale, | 166 | guidance_scale=guidance_scale, |
167 | sag_scale=0, | 167 | sag_scale=0, |
168 | num_inference_steps=num_steps, | 168 | num_inference_steps=num_steps, |
169 | output_type="pt", | ||
169 | ).images | 170 | ).images |
170 | 171 | ||
171 | all_samples.append(torch.from_numpy(samples)) | 172 | all_samples.append(samples) |
172 | 173 | ||
173 | all_samples = torch.cat(all_samples) | 174 | all_samples = torch.cat(all_samples) |
174 | 175 | ||
@@ -177,9 +178,9 @@ def save_samples( | |||
177 | # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC") | 178 | # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC") |
178 | pass | 179 | pass |
179 | 180 | ||
180 | image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) | 181 | image_grid = make_grid(all_samples, grid_cols) |
181 | image_grid = pipeline.numpy_to_pil( | 182 | image_grid = pipeline.image_processor.numpy_to_pil( |
182 | image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy() | 183 | pipeline.image_processor.pt_to_numpy(image_grid.unsqueeze(0)) |
183 | )[0] | 184 | )[0] |
184 | image_grid.save(file_path, quality=85) | 185 | image_grid.save(file_path, quality=85) |
185 | 186 | ||