diff options
| author | Volpeon <git@volpeon.ink> | 2023-06-24 18:59:29 +0200 | 
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-06-24 18:59:29 +0200 | 
| commit | 92231f9f85a4725daa8e619804d8c0aca3a7043e (patch) | |
| tree | b149318d5100edfdbf17b436e1ef566c77c6cc5f | |
| parent | Fixes (diff) | |
| download | textual-inversion-diff-92231f9f85a4725daa8e619804d8c0aca3a7043e.tar.gz textual-inversion-diff-92231f9f85a4725daa8e619804d8c0aca3a7043e.tar.bz2 textual-inversion-diff-92231f9f85a4725daa8e619804d8c0aca3a7043e.zip | |
Fix
| -rw-r--r-- | train_dreambooth.py | 2 | ||||
| -rw-r--r-- | training/functional.py | 9 | 
2 files changed, 6 insertions, 5 deletions
| diff --git a/train_dreambooth.py b/train_dreambooth.py index be4da1a..84197c8 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -462,7 +462,7 @@ def parse_args(): | |||
| 462 | parser.add_argument( | 462 | parser.add_argument( | 
| 463 | "--sample_steps", | 463 | "--sample_steps", | 
| 464 | type=int, | 464 | type=int, | 
| 465 | default=15, | 465 | default=10, | 
| 466 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 466 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 
| 467 | ) | 467 | ) | 
| 468 | parser.add_argument( | 468 | parser.add_argument( | 
| diff --git a/training/functional.py b/training/functional.py index 546aaff..34a701b 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -166,9 +166,10 @@ def save_samples( | |||
| 166 | guidance_scale=guidance_scale, | 166 | guidance_scale=guidance_scale, | 
| 167 | sag_scale=0, | 167 | sag_scale=0, | 
| 168 | num_inference_steps=num_steps, | 168 | num_inference_steps=num_steps, | 
| 169 | output_type="pt", | ||
| 169 | ).images | 170 | ).images | 
| 170 | 171 | ||
| 171 | all_samples.append(torch.from_numpy(samples)) | 172 | all_samples.append(samples) | 
| 172 | 173 | ||
| 173 | all_samples = torch.cat(all_samples) | 174 | all_samples = torch.cat(all_samples) | 
| 174 | 175 | ||
| @@ -177,9 +178,9 @@ def save_samples( | |||
| 177 | # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC") | 178 | # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC") | 
| 178 | pass | 179 | pass | 
| 179 | 180 | ||
| 180 | image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) | 181 | image_grid = make_grid(all_samples, grid_cols) | 
| 181 | image_grid = pipeline.numpy_to_pil( | 182 | image_grid = pipeline.image_processor.numpy_to_pil( | 
| 182 | image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy() | 183 | pipeline.image_processor.pt_to_numpy(image_grid.unsqueeze(0)) | 
| 183 | )[0] | 184 | )[0] | 
| 184 | image_grid.save(file_path, quality=85) | 185 | image_grid.save(file_path, quality=85) | 
| 185 | 186 | ||
