summaryrefslogtreecommitdiffstats
path: root/pipelines
diff options
context:
space:
mode:
Diffstat (limited to 'pipelines')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py3
1 files changed, 1 insertions, 2 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index dbd262f..ea2a656 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -375,8 +375,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
375 375
376 def decode_latents(self, latents): 376 def decode_latents(self, latents):
377 latents = 1 / self.vae.config.scaling_factor * latents 377 latents = 1 / self.vae.config.scaling_factor * latents
378 # image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample 378 image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample
379 image = self.vae.decode(latents).sample
380 image = (image / 2 + 0.5).clamp(0, 1) 379 image = (image / 2 + 0.5).clamp(0, 1)
381 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 380 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
382 image = image.cpu().permute(0, 2, 3, 1).float().numpy() 381 image = image.cpu().permute(0, 2, 3, 1).float().numpy()