summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py22
1 files changed, 17 insertions, 5 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 707b639..a43a8e4 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -236,12 +236,24 @@ class VlpnStableDiffusion(DiffusionPipeline):
236 def prepare_latents(self, batch_size, num_images_per_prompt, num_channels_latents, height, width, dtype, device, generator, latents=None): 236 def prepare_latents(self, batch_size, num_images_per_prompt, num_channels_latents, height, width, dtype, device, generator, latents=None):
237 shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) 237 shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
238 238
239 if isinstance(generator, list) and len(generator) != batch_size:
240 raise ValueError(
241 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
242 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
243 )
244
239 if latents is None: 245 if latents is None:
240 if device.type == "mps": 246 rand_device = "cpu" if device.type == "mps" else device
241 # randn does not work reproducibly on mps 247
242 latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) 248 if isinstance(generator, list):
249 shape = (1,) + shape[1:]
250 latents = [
251 torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
252 for i in range(batch_size)
253 ]
254 latents = torch.cat(latents, dim=0).to(device)
243 else: 255 else:
244 latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) 256 latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
245 else: 257 else:
246 if latents.shape != shape: 258 if latents.shape != shape:
247 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") 259 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
@@ -311,7 +323,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
311 num_inference_steps: Optional[int] = 50, 323 num_inference_steps: Optional[int] = 50,
312 guidance_scale: Optional[float] = 7.5, 324 guidance_scale: Optional[float] = 7.5,
313 eta: Optional[float] = 0.0, 325 eta: Optional[float] = 0.0,
314 generator: Optional[torch.Generator] = None, 326 generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
315 image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, 327 image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
316 output_type: Optional[str] = "pil", 328 output_type: Optional[str] = "pil",
317 return_dict: bool = True, 329 return_dict: bool = True,