From c36b3a32964a5701edd3df1df8533cfdbb89d3cf Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 19 Dec 2022 21:11:21 +0100 Subject: Upstream patches --- .../stable_diffusion/vlpn_stable_diffusion.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) (limited to 'pipelines') diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 707b639..a43a8e4 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -236,12 +236,24 @@ class VlpnStableDiffusion(DiffusionPipeline): def prepare_latents(self, batch_size, num_images_per_prompt, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: - if device.type == "mps": - # randn does not work reproducibly on mps - latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) else: - latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") @@ -311,7 +323,7 @@ class VlpnStableDiffusion(DiffusionPipeline): num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, -- cgit v1.2.3-54-g00ecf