summaryrefslogtreecommitdiffstats
path: root/pipelines
diff options
context:
space:
mode:
Diffstat (limited to 'pipelines')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 4505a2a..dbd262f 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -291,7 +291,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
291 else: 291 else:
292 attention_mask = None 292 attention_mask = None
293 293
294 prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask) 294 prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids.to(device), attention_mask)
295 prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
295 296
296 return prompt_embeds 297 return prompt_embeds
297 298
@@ -374,6 +375,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
374 375
375 def decode_latents(self, latents): 376 def decode_latents(self, latents):
376 latents = 1 / self.vae.config.scaling_factor * latents 377 latents = 1 / self.vae.config.scaling_factor * latents
378 # image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample
377 image = self.vae.decode(latents).sample 379 image = self.vae.decode(latents).sample
378 image = (image / 2 + 0.5).clamp(0, 1) 380 image = (image / 2 + 0.5).clamp(0, 1)
379 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 381 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16