diff options
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 22 |
1 files changed, 17 insertions, 5 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 707b639..a43a8e4 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
| @@ -236,12 +236,24 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 236 | def prepare_latents(self, batch_size, num_images_per_prompt, num_channels_latents, height, width, dtype, device, generator, latents=None): | 236 | def prepare_latents(self, batch_size, num_images_per_prompt, num_channels_latents, height, width, dtype, device, generator, latents=None): |
| 237 | shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) | 237 | shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) |
| 238 | 238 | ||
| 239 | if isinstance(generator, list) and len(generator) != batch_size: | ||
| 240 | raise ValueError( | ||
| 241 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" | ||
| 242 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." | ||
| 243 | ) | ||
| 244 | |||
| 239 | if latents is None: | 245 | if latents is None: |
| 240 | if device.type == "mps": | 246 | rand_device = "cpu" if device.type == "mps" else device |
| 241 | # randn does not work reproducibly on mps | 247 | |
| 242 | latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) | 248 | if isinstance(generator, list): |
| 249 | shape = (1,) + shape[1:] | ||
| 250 | latents = [ | ||
| 251 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) | ||
| 252 | for i in range(batch_size) | ||
| 253 | ] | ||
| 254 | latents = torch.cat(latents, dim=0).to(device) | ||
| 243 | else: | 255 | else: |
| 244 | latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) | 256 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) |
| 245 | else: | 257 | else: |
| 246 | if latents.shape != shape: | 258 | if latents.shape != shape: |
| 247 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") | 259 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") |
| @@ -311,7 +323,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 311 | num_inference_steps: Optional[int] = 50, | 323 | num_inference_steps: Optional[int] = 50, |
| 312 | guidance_scale: Optional[float] = 7.5, | 324 | guidance_scale: Optional[float] = 7.5, |
| 313 | eta: Optional[float] = 0.0, | 325 | eta: Optional[float] = 0.0, |
| 314 | generator: Optional[torch.Generator] = None, | 326 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
| 315 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, | 327 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, |
| 316 | output_type: Optional[str] = "pil", | 328 | output_type: Optional[str] = "pil", |
| 317 | return_dict: bool = True, | 329 | return_dict: bool = True, |
