author    Volpeon <git@volpeon.ink>    2022-11-01 16:19:01 +0100
committer Volpeon <git@volpeon.ink>    2022-11-01 16:19:01 +0100
commit    b2c3389e9c6375d9081625e75a99de98395f8e77 (patch)
tree      d230b417314960e8705abd2eeaa3b55d9b70c754 /pipelines/stable_diffusion
parent    Fix (diff)
Update
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  |  5
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index fc12355..cd5ae7e 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -203,6 +203,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # However this currently doesn't work in `mps`.
         latents_dtype = text_embeddings.dtype
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+
         if latents is None:
             if self.device.type == "mps":
                 # randn does not exist on mps
@@ -264,7 +265,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t, i)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -275,7 +276,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, i, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
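
For context: both changed calls drop the loop index `i`, matching the diffusers scheduler interface in which `scale_model_input()` and `step()` are keyed on the timestep `t` alone. Below is a minimal, self-contained sketch of the denoising loop after this commit. It assumes a recent diffusers release; `EulerDiscreteScheduler` is a stand-in for whichever scheduler the pipeline is actually run with, and `fake_unet` is a hypothetical stub so the loop runs without model weights.

import torch
from diffusers import EulerDiscreteScheduler

# Stand-in scheduler; the real pipeline accepts any scheduler with this interface.
scheduler = EulerDiscreteScheduler()
scheduler.set_timesteps(10)

guidance_scale = 7.5
do_classifier_free_guidance = guidance_scale > 1.0

# Start from scaled Gaussian noise, as the pipeline does when `latents is None`.
latents = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma

def fake_unet(sample, t):
    # Hypothetical stub standing in for self.unet(...).sample.
    return torch.randn_like(sample)

for t in scheduler.timesteps:
    # expand the latents if we are doing classifier free guidance
    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
    # post-commit signature: timestep only, no loop index
    latent_model_input = scheduler.scale_model_input(latent_model_input, t)

    # predict the noise residual
    noise_pred = fake_unet(latent_model_input, t)

    if do_classifier_free_guidance:
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # post-commit signature: step(noise_pred, t, latents)
    latents = scheduler.step(noise_pred, t, latents).prev_sample

Since the scheduler resolves its internal step index from `t` itself, passing `i` separately is redundant; removing it also keeps the pipeline compatible with schedulers whose `scale_model_input()` and `step()` accept only the timestep.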