diff options
author | Volpeon <git@volpeon.ink> | 2023-06-24 18:59:29 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-06-24 18:59:29 +0200 |
commit | 92231f9f85a4725daa8e619804d8c0aca3a7043e (patch) | |
tree | b149318d5100edfdbf17b436e1ef566c77c6cc5f | |
parent | Fixes (diff) | |
download | textual-inversion-diff-92231f9f85a4725daa8e619804d8c0aca3a7043e.tar.gz textual-inversion-diff-92231f9f85a4725daa8e619804d8c0aca3a7043e.tar.bz2 textual-inversion-diff-92231f9f85a4725daa8e619804d8c0aca3a7043e.zip |
Fix
-rw-r--r-- | train_dreambooth.py | 2 | ||||
-rw-r--r-- | training/functional.py | 9 |
2 files changed, 6 insertions, 5 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index be4da1a..84197c8 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -462,7 +462,7 @@ def parse_args(): | |||
462 | parser.add_argument( | 462 | parser.add_argument( |
463 | "--sample_steps", | 463 | "--sample_steps", |
464 | type=int, | 464 | type=int, |
465 | default=15, | 465 | default=10, |
466 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 466 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", |
467 | ) | 467 | ) |
468 | parser.add_argument( | 468 | parser.add_argument( |
diff --git a/training/functional.py b/training/functional.py index 546aaff..34a701b 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -166,9 +166,10 @@ def save_samples( | |||
166 | guidance_scale=guidance_scale, | 166 | guidance_scale=guidance_scale, |
167 | sag_scale=0, | 167 | sag_scale=0, |
168 | num_inference_steps=num_steps, | 168 | num_inference_steps=num_steps, |
169 | output_type="pt", | ||
169 | ).images | 170 | ).images |
170 | 171 | ||
171 | all_samples.append(torch.from_numpy(samples)) | 172 | all_samples.append(samples) |
172 | 173 | ||
173 | all_samples = torch.cat(all_samples) | 174 | all_samples = torch.cat(all_samples) |
174 | 175 | ||
@@ -177,9 +178,9 @@ def save_samples( | |||
177 | # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC") | 178 | # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC") |
178 | pass | 179 | pass |
179 | 180 | ||
180 | image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) | 181 | image_grid = make_grid(all_samples, grid_cols) |
181 | image_grid = pipeline.numpy_to_pil( | 182 | image_grid = pipeline.image_processor.numpy_to_pil( |
182 | image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy() | 183 | pipeline.image_processor.pt_to_numpy(image_grid.unsqueeze(0)) |
183 | )[0] | 184 | )[0] |
184 | image_grid.save(file_path, quality=85) | 185 | image_grid.save(file_path, quality=85) |
185 | 186 | ||