diff options
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index a198cf6..bfecd1c 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -234,7 +234,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
234 | latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) | 234 | latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) |
235 | elif isinstance(latents, PIL.Image.Image): | 235 | elif isinstance(latents, PIL.Image.Image): |
236 | latents = preprocess(latents, width, height) | 236 | latents = preprocess(latents, width, height) |
237 | latent_dist = self.vae.encode(latents.to(self.device)).latent_dist | 237 | latents = latents.to(device=self.device, dtype=latents_dtype) |
238 | latent_dist = self.vae.encode(latents).latent_dist | ||
238 | latents = latent_dist.sample(generator=generator) | 239 | latents = latent_dist.sample(generator=generator) |
239 | latents = 0.18215 * latents | 240 | latents = 0.18215 * latents |
240 | 241 | ||
@@ -249,7 +250,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
249 | timesteps = torch.tensor([timesteps] * batch_size, device=self.device) | 250 | timesteps = torch.tensor([timesteps] * batch_size, device=self.device) |
250 | 251 | ||
251 | # add noise to latents using the timesteps | 252 | # add noise to latents using the timesteps |
252 | noise = torch.randn(latents.shape, generator=generator, device=self.device) | 253 | noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype) |
253 | latents = self.scheduler.add_noise(latents, noise, timesteps) | 254 | latents = self.scheduler.add_noise(latents, noise, timesteps) |
254 | else: | 255 | else: |
255 | if latents.shape != latents_shape: | 256 | if latents.shape != latents_shape: |