From 92231f9f85a4725daa8e619804d8c0aca3a7043e Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Jun 2023 18:59:29 +0200 Subject: Fix --- training/functional.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 546aaff..34a701b 100644 --- a/training/functional.py +++ b/training/functional.py @@ -166,9 +166,10 @@ def save_samples( guidance_scale=guidance_scale, sag_scale=0, num_inference_steps=num_steps, + output_type="pt", ).images - all_samples.append(torch.from_numpy(samples)) + all_samples.append(samples) all_samples = torch.cat(all_samples) @@ -177,9 +178,9 @@ def save_samples( # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC") pass - image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) - image_grid = pipeline.numpy_to_pil( - image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy() + image_grid = make_grid(all_samples, grid_cols) + image_grid = pipeline.image_processor.numpy_to_pil( + pipeline.image_processor.pt_to_numpy(image_grid.unsqueeze(0)) )[0] image_grid.save(file_path, quality=85) -- cgit v1.2.3-70-g09d2