summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-06-24 18:59:29 +0200
committerVolpeon <git@volpeon.ink>2023-06-24 18:59:29 +0200
commit92231f9f85a4725daa8e619804d8c0aca3a7043e (patch)
treeb149318d5100edfdbf17b436e1ef566c77c6cc5f
parentFixes (diff)
downloadtextual-inversion-diff-92231f9f85a4725daa8e619804d8c0aca3a7043e.tar.gz
textual-inversion-diff-92231f9f85a4725daa8e619804d8c0aca3a7043e.tar.bz2
textual-inversion-diff-92231f9f85a4725daa8e619804d8c0aca3a7043e.zip
Fix
-rw-r--r--train_dreambooth.py2
-rw-r--r--training/functional.py9
2 files changed, 6 insertions, 5 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index be4da1a..84197c8 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -462,7 +462,7 @@ def parse_args():
462 parser.add_argument( 462 parser.add_argument(
463 "--sample_steps", 463 "--sample_steps",
464 type=int, 464 type=int,
465 default=15, 465 default=10,
466 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", 466 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
467 ) 467 )
468 parser.add_argument( 468 parser.add_argument(
diff --git a/training/functional.py b/training/functional.py
index 546aaff..34a701b 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -166,9 +166,10 @@ def save_samples(
166 guidance_scale=guidance_scale, 166 guidance_scale=guidance_scale,
167 sag_scale=0, 167 sag_scale=0,
168 num_inference_steps=num_steps, 168 num_inference_steps=num_steps,
169 output_type="pt",
169 ).images 170 ).images
170 171
171 all_samples.append(torch.from_numpy(samples)) 172 all_samples.append(samples)
172 173
173 all_samples = torch.cat(all_samples) 174 all_samples = torch.cat(all_samples)
174 175
@@ -177,9 +178,9 @@ def save_samples(
177 # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC") 178 # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC")
178 pass 179 pass
179 180
180 image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) 181 image_grid = make_grid(all_samples, grid_cols)
181 image_grid = pipeline.numpy_to_pil( 182 image_grid = pipeline.image_processor.numpy_to_pil(
182 image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy() 183 pipeline.image_processor.pt_to_numpy(image_grid.unsqueeze(0))
183 )[0] 184 )[0]
184 image_grid.save(file_path, quality=85) 185 image_grid.save(file_path, quality=85)
185 186