From 64c594869135354a38353551bd58a93e15bd5b85 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 2 Oct 2022 20:57:43 +0200 Subject: Small performance improvements --- pipelines/stable_diffusion/vlpn_stable_diffusion.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) (limited to 'pipelines') diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index b4c85e9..8fbe5f9 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -223,15 +223,16 @@ class VlpnStableDiffusion(DiffusionPipeline): # Unlike in other pipelines, latents need to be generated in the target device # for 1-to-1 results reproducibility with the CompVis implementation. # However this currently doesn't work in `mps`. - latents_device = "cpu" if self.device.type == "mps" else self.device + latents_dtype = text_embeddings.dtype latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) if latents is None: - latents = torch.randn( - latents_shape, - generator=generator, - device=latents_device, - dtype=text_embeddings.dtype, - ) + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( + self.device + ) + else: + latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) elif isinstance(latents, PIL.Image.Image): latents = preprocess(latents, width, height) latent_dist = self.vae.encode(latents.to(self.device)).latent_dist @@ -259,7 +260,8 @@ class VlpnStableDiffusion(DiffusionPipeline): else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) + if latents.device != self.device: + raise ValueError(f"Unexpected latents device, got {latents.device}, expected {self.device}") # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas if ensure_sigma: -- cgit v1.2.3-70-g09d2