summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/training/functional.py b/training/functional.py
index 546aaff..34a701b 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -166,9 +166,10 @@ def save_samples(
166 guidance_scale=guidance_scale, 166 guidance_scale=guidance_scale,
167 sag_scale=0, 167 sag_scale=0,
168 num_inference_steps=num_steps, 168 num_inference_steps=num_steps,
169 output_type="pt",
169 ).images 170 ).images
170 171
171 all_samples.append(torch.from_numpy(samples)) 172 all_samples.append(samples)
172 173
173 all_samples = torch.cat(all_samples) 174 all_samples = torch.cat(all_samples)
174 175
@@ -177,9 +178,9 @@ def save_samples(
177 # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC") 178 # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC")
178 pass 179 pass
179 180
180 image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) 181 image_grid = make_grid(all_samples, grid_cols)
181 image_grid = pipeline.numpy_to_pil( 182 image_grid = pipeline.image_processor.numpy_to_pil(
182 image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy() 183 pipeline.image_processor.pt_to_numpy(image_grid.unsqueeze(0))
183 )[0] 184 )[0]
184 image_grid.save(file_path, quality=85) 185 image_grid.save(file_path, quality=85)
185 186