diff options
Diffstat (limited to 'pipelines')
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 17 |
1 files changed, 6 insertions, 11 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 4c793a8..a8ecedf 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
| @@ -185,6 +185,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 185 | latent_dist = self.vae.encode(latents.to(self.device)).latent_dist | 185 | latent_dist = self.vae.encode(latents.to(self.device)).latent_dist |
| 186 | latents = latent_dist.sample(generator=generator) | 186 | latents = latent_dist.sample(generator=generator) |
| 187 | latents = 0.18215 * latents | 187 | latents = 0.18215 * latents |
| 188 | |||
| 189 | # expand init_latents for batch_size | ||
| 188 | latents = torch.cat([latents] * batch_size) | 190 | latents = torch.cat([latents] * batch_size) |
| 189 | 191 | ||
| 190 | # get the original timestep using init_timestep | 192 | # get the original timestep using init_timestep |
| @@ -195,9 +197,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 195 | timesteps = torch.tensor( | 197 | timesteps = torch.tensor( |
| 196 | [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device | 198 | [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device |
| 197 | ) | 199 | ) |
| 198 | elif isinstance(self.scheduler, EulerAScheduler): | ||
| 199 | timesteps = self.scheduler.timesteps[-init_timestep] | ||
| 200 | timesteps = torch.tensor([timesteps] * batch_size, device=self.device) | ||
| 201 | else: | 200 | else: |
| 202 | timesteps = self.scheduler.timesteps[-init_timestep] | 201 | timesteps = self.scheduler.timesteps[-init_timestep] |
| 203 | timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) | 202 | timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) |
| @@ -273,8 +272,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 273 | if isinstance(self.scheduler, LMSDiscreteScheduler): | 272 | if isinstance(self.scheduler, LMSDiscreteScheduler): |
| 274 | latents = latents * self.scheduler.sigmas[0] | 273 | latents = latents * self.scheduler.sigmas[0] |
| 275 | elif isinstance(self.scheduler, EulerAScheduler): | 274 | elif isinstance(self.scheduler, EulerAScheduler): |
| 276 | sigma = self.scheduler.timesteps[0] | 275 | latents = latents * self.scheduler.sigmas[0] |
| 277 | latents = latents * sigma | ||
| 278 | 276 | ||
| 279 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature | 277 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature |
| 280 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. | 278 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. |
| @@ -301,12 +299,10 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 301 | 299 | ||
| 302 | noise_pred = None | 300 | noise_pred = None |
| 303 | if isinstance(self.scheduler, EulerAScheduler): | 301 | if isinstance(self.scheduler, EulerAScheduler): |
| 304 | sigma = t.reshape(1) | 302 | sigma = self.scheduler.sigmas[t].reshape(1) |
| 305 | sigma_in = torch.cat([sigma] * latent_model_input.shape[0]) | 303 | sigma_in = torch.cat([sigma] * latent_model_input.shape[0]) |
| 306 | # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale) | ||
| 307 | noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, | 304 | noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, |
| 308 | text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas) | 305 | text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas) |
| 309 | # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample | ||
| 310 | else: | 306 | else: |
| 311 | # predict the noise residual | 307 | # predict the noise residual |
| 312 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample | 308 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample |
| @@ -320,9 +316,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 320 | if isinstance(self.scheduler, LMSDiscreteScheduler): | 316 | if isinstance(self.scheduler, LMSDiscreteScheduler): |
| 321 | latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample | 317 | latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample |
| 322 | elif isinstance(self.scheduler, EulerAScheduler): | 318 | elif isinstance(self.scheduler, EulerAScheduler): |
| 323 | if t_index < self.scheduler.timesteps.shape[0] - 1: # avoid out of bound error | 319 | latents = self.scheduler.step(noise_pred, t_index, t_index + 1, |
| 324 | t_prev = self.scheduler.timesteps[t_index+1] | 320 | latents, **extra_step_kwargs).prev_sample |
| 325 | latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample | ||
| 326 | else: | 321 | else: |
| 327 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample | 322 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample |
| 328 | 323 | ||
