diff options
author | Volpeon <git@volpeon.ink> | 2023-03-23 11:07:57 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-23 11:07:57 +0100 |
commit | 0767c7bc82645186159965c2a6be4278e33c6721 (patch) | |
tree | a136470ab85dbb99ab51d9be4a7831fe21612ab3 /pipelines/stable_diffusion | |
parent | Fix (diff) | |
download | textual-inversion-diff-0767c7bc82645186159965c2a6be4278e33c6721.tar.gz textual-inversion-diff-0767c7bc82645186159965c2a6be4278e33c6721.tar.bz2 textual-inversion-diff-0767c7bc82645186159965c2a6be4278e33c6721.zip |
Update
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 4505a2a..dbd262f 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -291,7 +291,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
291 | else: | 291 | else: |
292 | attention_mask = None | 292 | attention_mask = None |
293 | 293 | ||
294 | prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask) | 294 | prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids.to(device), attention_mask) |
295 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) | ||
295 | 296 | ||
296 | return prompt_embeds | 297 | return prompt_embeds |
297 | 298 | ||
@@ -374,6 +375,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
374 | 375 | ||
375 | def decode_latents(self, latents): | 376 | def decode_latents(self, latents): |
376 | latents = 1 / self.vae.config.scaling_factor * latents | 377 | latents = 1 / self.vae.config.scaling_factor * latents |
378 | # image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample | ||
377 | image = self.vae.decode(latents).sample | 379 | image = self.vae.decode(latents).sample |
378 | image = (image / 2 + 0.5).clamp(0, 1) | 380 | image = (image / 2 + 0.5).clamp(0, 1) |
379 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 | 381 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 |