diff options
Diffstat (limited to 'pipelines/stable_diffusion')
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 3 |
1 files changed, 1 insertions, 2 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index dbd262f..ea2a656 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
| @@ -375,8 +375,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 375 | 375 | ||
| 376 | def decode_latents(self, latents): | 376 | def decode_latents(self, latents): |
| 377 | latents = 1 / self.vae.config.scaling_factor * latents | 377 | latents = 1 / self.vae.config.scaling_factor * latents |
| 378 | # image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample | 378 | image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample |
| 379 | image = self.vae.decode(latents).sample | ||
| 380 | image = (image / 2 + 0.5).clamp(0, 1) | 379 | image = (image / 2 + 0.5).clamp(0, 1) |
| 381 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 | 380 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 |
| 382 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() | 381 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() |
