From 185c6b520d2136c87b122b89b52a0cc013240c6e Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Fri, 24 Mar 2023 11:50:22 +0100
Subject: Fixed Lora training perf issue

---
 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

(limited to 'pipelines')

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index dbd262f..ea2a656 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -375,8 +375,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        # image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-- 
cgit v1.2.3-70-g09d2