Diffstat (limited to 'pipelines/stable_diffusion')
 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index b4c85e9..8fbe5f9 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -223,15 +223,16 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_device = "cpu" if self.device.type == "mps" else self.device
+        latents_dtype = text_embeddings.dtype
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
         if latents is None:
-            latents = torch.randn(
-                latents_shape,
-                generator=generator,
-                device=latents_device,
-                dtype=text_embeddings.dtype,
-            )
+            if self.device.type == "mps":
+                # randn does not exist on mps
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+                    self.device
+                )
+            else:
+                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
         elif isinstance(latents, PIL.Image.Image):
             latents = preprocess(latents, width, height)
             latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
@@ -259,7 +260,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(self.device)
+            if latents.device != self.device:
+                raise ValueError(f"Unexpected latents device, got {latents.device}, expected {self.device}")
 
         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if ensure_sigma:
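
For reference, the change amounts to this pattern: on the mps backend, latents are sampled with torch.randn on the CPU (the in-diff comment notes randn does not exist on mps) and then moved to the target device, while user-supplied latent tensors are now validated for device rather than silently moved. A minimal standalone sketch of that logic, assuming plain torch outside the pipeline; the make_latents name and signature are illustrative, not part of the patch:

import torch

def make_latents(batch_size, in_channels, height, width, device, dtype, generator=None, latents=None):
    # Hypothetical helper mirroring the patched initialization logic.
    latents_shape = (batch_size, in_channels, height // 8, width // 8)
    if latents is None:
        if device.type == "mps":
            # Sample on the CPU with the given generator, then move to mps.
            latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=dtype).to(device)
        else:
            latents = torch.randn(latents_shape, generator=generator, device=device, dtype=dtype)
    else:
        # Pre-supplied latents must already match the expected shape and device.
        if latents.shape != latents_shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
        if latents.device != device:
            raise ValueError(f"Unexpected latents device, got {latents.device}, expected {device}")
    return latents

# Example usage (CPU, seeded for reproducibility):
# latents = make_latents(1, 4, 512, 512, torch.device("cpu"), torch.float32,
#                        generator=torch.Generator().manual_seed(0))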