summaryrefslogtreecommitdiffstats
path: root/pipelines/stable_diffusion
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-08 21:56:54 +0200
committerVolpeon <git@volpeon.ink>2022-10-08 21:56:54 +0200
commit6aadb34af4fe5ca2dfc92fae8eee87610a5848ad (patch)
treef490b4794366e78f7b079eb04de1c7c00e17d34a /pipelines/stable_diffusion
parentFix small details (diff)
downloadtextual-inversion-diff-6aadb34af4fe5ca2dfc92fae8eee87610a5848ad.tar.gz
textual-inversion-diff-6aadb34af4fe5ca2dfc92fae8eee87610a5848ad.tar.bz2
textual-inversion-diff-6aadb34af4fe5ca2dfc92fae8eee87610a5848ad.zip
Update
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index a198cf6..bfecd1c 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -234,7 +234,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
234 latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) 234 latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
235 elif isinstance(latents, PIL.Image.Image): 235 elif isinstance(latents, PIL.Image.Image):
236 latents = preprocess(latents, width, height) 236 latents = preprocess(latents, width, height)
237 latent_dist = self.vae.encode(latents.to(self.device)).latent_dist 237 latents = latents.to(device=self.device, dtype=latents_dtype)
238 latent_dist = self.vae.encode(latents).latent_dist
238 latents = latent_dist.sample(generator=generator) 239 latents = latent_dist.sample(generator=generator)
239 latents = 0.18215 * latents 240 latents = 0.18215 * latents
240 241
@@ -249,7 +250,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
249 timesteps = torch.tensor([timesteps] * batch_size, device=self.device) 250 timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
250 251
251 # add noise to latents using the timesteps 252 # add noise to latents using the timesteps
252 noise = torch.randn(latents.shape, generator=generator, device=self.device) 253 noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
253 latents = self.scheduler.add_noise(latents, noise, timesteps) 254 latents = self.scheduler.add_noise(latents, noise, timesteps)
254 else: 255 else:
255 if latents.shape != latents_shape: 256 if latents.shape != latents_shape: