From 185c6b520d2136c87b122b89b52a0cc013240c6e Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 24 Mar 2023 11:50:22 +0100 Subject: Fixed Lora training perf issue --- pipelines/stable_diffusion/vlpn_stable_diffusion.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'pipelines/stable_diffusion') diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index dbd262f..ea2a656 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -375,8 +375,7 @@ class VlpnStableDiffusion(DiffusionPipeline): def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents - # image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample - image = self.vae.decode(latents).sample + image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() -- cgit v1.2.3-54-g00ecf