diff options
author | Volpeon <git@volpeon.ink> | 2022-11-01 16:19:01 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-11-01 16:19:01 +0100 |
commit | b2c3389e9c6375d9081625e75a99de98395f8e77 (patch) | |
tree | d230b417314960e8705abd2eeaa3b55d9b70c754 /pipelines/stable_diffusion | |
parent | Fix (diff) | |
download | textual-inversion-diff-b2c3389e9c6375d9081625e75a99de98395f8e77.tar.gz textual-inversion-diff-b2c3389e9c6375d9081625e75a99de98395f8e77.tar.bz2 textual-inversion-diff-b2c3389e9c6375d9081625e75a99de98395f8e77.zip |
Update
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index fc12355..cd5ae7e 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -203,6 +203,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
203 | # However this currently doesn't work in `mps`. | 203 | # However this currently doesn't work in `mps`. |
204 | latents_dtype = text_embeddings.dtype | 204 | latents_dtype = text_embeddings.dtype |
205 | latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) | 205 | latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) |
206 | |||
206 | if latents is None: | 207 | if latents is None: |
207 | if self.device.type == "mps": | 208 | if self.device.type == "mps": |
208 | # randn does not exist on mps | 209 | # randn does not exist on mps |
@@ -264,7 +265,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
264 | for i, t in enumerate(self.progress_bar(timesteps_tensor)): | 265 | for i, t in enumerate(self.progress_bar(timesteps_tensor)): |
265 | # expand the latents if we are doing classifier free guidance | 266 | # expand the latents if we are doing classifier free guidance |
266 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | 267 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
267 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t, i) | 268 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
268 | 269 | ||
269 | # predict the noise residual | 270 | # predict the noise residual |
270 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample | 271 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample |
@@ -275,7 +276,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
275 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | 276 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
276 | 277 | ||
277 | # compute the previous noisy sample x_t -> x_t-1 | 278 | # compute the previous noisy sample x_t -> x_t-1 |
278 | latents = self.scheduler.step(noise_pred, t, i, latents, **extra_step_kwargs).prev_sample | 279 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample |
279 | 280 | ||
280 | # scale and decode the image latents with vae | 281 | # scale and decode the image latents with vae |
281 | latents = 1 / 0.18215 * latents | 282 | latents = 1 / 0.18215 * latents |