diff options
Diffstat (limited to 'pipelines')
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 4 |
1 file changed, 3 insertions, 1 deletion
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 4505a2a..dbd262f 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
| @@ -291,7 +291,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 291 | else: | 291 | else: |
| 292 | attention_mask = None | 292 | attention_mask = None |
| 293 | 293 | ||
| 294 | prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask) | 294 | prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids.to(device), attention_mask) |
| 295 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) | ||
| 295 | 296 | ||
| 296 | return prompt_embeds | 297 | return prompt_embeds |
| 297 | 298 | ||
| @@ -374,6 +375,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 374 | 375 | ||
| 375 | def decode_latents(self, latents): | 376 | def decode_latents(self, latents): |
| 376 | latents = 1 / self.vae.config.scaling_factor * latents | 377 | latents = 1 / self.vae.config.scaling_factor * latents |
| 378 | # image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample | ||
| 377 | image = self.vae.decode(latents).sample | 379 | image = self.vae.decode(latents).sample |
| 378 | image = (image / 2 + 0.5).clamp(0, 1) | 380 | image = (image / 2 + 0.5).clamp(0, 1) |
| 379 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 | 381 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 |
