From 0767c7bc82645186159965c2a6be4278e33c6721 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 23 Mar 2023 11:07:57 +0100
Subject: Update

---
 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 4505a2a..dbd262f 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -291,7 +291,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
         else:
             attention_mask = None
 
-        prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask)
+        prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids.to(device), attention_mask)
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
 
         return prompt_embeds
 
@@ -374,6 +375,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
+        # image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
-- 
cgit v1.2.3-54-g00ecf