diff options
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index c4f7401..242be29 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -306,13 +306,19 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
306 | 306 | ||
307 | return timesteps | 307 | return timesteps |
308 | 308 | ||
309 | def prepare_image(self, batch_size, width, height, dtype, device, generator=None): | 309 | def prepare_image(self, batch_size, width, height, max_offset, dtype, device, generator=None): |
310 | return torch.randn( | 310 | offset = (max_offset * (2 * torch.rand( |
311 | (batch_size, 1, 1, 1), | 311 | (batch_size, 1, 1, 1), |
312 | dtype=dtype, | 312 | dtype=dtype, |
313 | device=device, | 313 | device=device, |
314 | generator=generator | 314 | generator=generator |
315 | ).expand(batch_size, 3, width, height) | 315 | ) - 1)).expand(batch_size, 3, width, height) |
316 | image = (.1 * torch.normal( | ||
317 | mean=offset, | ||
318 | std=1, | ||
319 | generator=generator | ||
320 | )).clamp(-1, 1) | ||
321 | return image | ||
316 | 322 | ||
317 | def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None): | 323 | def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None): |
318 | init_image = init_image.to(device=device, dtype=dtype) | 324 | init_image = init_image.to(device=device, dtype=dtype) |
@@ -376,6 +382,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
376 | eta: float = 0.0, | 382 | eta: float = 0.0, |
377 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | 383 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
378 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, | 384 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, |
385 | max_image_offset: float = 1.0, | ||
379 | output_type: str = "pil", | 386 | output_type: str = "pil", |
380 | return_dict: bool = True, | 387 | return_dict: bool = True, |
381 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | 388 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
@@ -469,6 +476,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
469 | batch_size * num_images_per_prompt, | 476 | batch_size * num_images_per_prompt, |
470 | width, | 477 | width, |
471 | height, | 478 | height, |
479 | max_image_offset, | ||
472 | prompt_embeds.dtype, | 480 | prompt_embeds.dtype, |
473 | device, | 481 | device, |
474 | generator | 482 | generator |