diff options
Diffstat (limited to 'pipelines/stable_diffusion')
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index a198cf6..bfecd1c 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
| @@ -234,7 +234,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 234 | latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) | 234 | latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) |
| 235 | elif isinstance(latents, PIL.Image.Image): | 235 | elif isinstance(latents, PIL.Image.Image): |
| 236 | latents = preprocess(latents, width, height) | 236 | latents = preprocess(latents, width, height) |
| 237 | latent_dist = self.vae.encode(latents.to(self.device)).latent_dist | 237 | latents = latents.to(device=self.device, dtype=latents_dtype) |
| 238 | latent_dist = self.vae.encode(latents).latent_dist | ||
| 238 | latents = latent_dist.sample(generator=generator) | 239 | latents = latent_dist.sample(generator=generator) |
| 239 | latents = 0.18215 * latents | 240 | latents = 0.18215 * latents |
| 240 | 241 | ||
| @@ -249,7 +250,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 249 | timesteps = torch.tensor([timesteps] * batch_size, device=self.device) | 250 | timesteps = torch.tensor([timesteps] * batch_size, device=self.device) |
| 250 | 251 | ||
| 251 | # add noise to latents using the timesteps | 252 | # add noise to latents using the timesteps |
| 252 | noise = torch.randn(latents.shape, generator=generator, device=self.device) | 253 | noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype) |
| 253 | latents = self.scheduler.add_noise(latents, noise, timesteps) | 254 | latents = self.scheduler.add_noise(latents, noise, timesteps) |
| 254 | else: | 255 | else: |
| 255 | if latents.shape != latents_shape: | 256 | if latents.shape != latents_shape: |
