diff options
Diffstat (limited to 'pipelines')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 3 |
1 files changed, 1 insertions, 2 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index dbd262f..ea2a656 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -375,8 +375,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
375 | 375 | ||
376 | def decode_latents(self, latents): | 376 | def decode_latents(self, latents): |
377 | latents = 1 / self.vae.config.scaling_factor * latents | 377 | latents = 1 / self.vae.config.scaling_factor * latents |
378 | # image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample | 378 | image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample |
379 | image = self.vae.decode(latents).sample | ||
380 | image = (image / 2 + 0.5).clamp(0, 1) | 379 | image = (image / 2 + 0.5).clamp(0, 1) |
381 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 | 380 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 |
382 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() | 381 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() |