From 92231f9f85a4725daa8e619804d8c0aca3a7043e Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Jun 2023 18:59:29 +0200 Subject: Fix --- train_dreambooth.py | 2 +- training/functional.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/train_dreambooth.py b/train_dreambooth.py index be4da1a..84197c8 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -462,7 +462,7 @@ def parse_args(): parser.add_argument( "--sample_steps", type=int, - default=15, + default=10, help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", ) parser.add_argument( diff --git a/training/functional.py b/training/functional.py index 546aaff..34a701b 100644 --- a/training/functional.py +++ b/training/functional.py @@ -166,9 +166,10 @@ def save_samples( guidance_scale=guidance_scale, sag_scale=0, num_inference_steps=num_steps, + output_type="pt", ).images - all_samples.append(torch.from_numpy(samples)) + all_samples.append(samples) all_samples = torch.cat(all_samples) @@ -177,9 +178,9 @@ def save_samples( # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC") pass - image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) - image_grid = pipeline.numpy_to_pil( - image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy() + image_grid = make_grid(all_samples, grid_cols) + image_grid = pipeline.image_processor.numpy_to_pil( + pipeline.image_processor.pt_to_numpy(image_grid.unsqueeze(0)) )[0] image_grid.save(file_path, quality=85) -- cgit v1.2.3-54-g00ecf