diff options
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 17 |
1 file changed, 6 insertions, 11 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 4c793a8..a8ecedf 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -185,6 +185,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
185 | latent_dist = self.vae.encode(latents.to(self.device)).latent_dist | 185 | latent_dist = self.vae.encode(latents.to(self.device)).latent_dist |
186 | latents = latent_dist.sample(generator=generator) | 186 | latents = latent_dist.sample(generator=generator) |
187 | latents = 0.18215 * latents | 187 | latents = 0.18215 * latents |
188 | |||
189 | # expand init_latents for batch_size | ||
188 | latents = torch.cat([latents] * batch_size) | 190 | latents = torch.cat([latents] * batch_size) |
189 | 191 | ||
190 | # get the original timestep using init_timestep | 192 | # get the original timestep using init_timestep |
@@ -195,9 +197,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
195 | timesteps = torch.tensor( | 197 | timesteps = torch.tensor( |
196 | [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device | 198 | [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device |
197 | ) | 199 | ) |
198 | elif isinstance(self.scheduler, EulerAScheduler): | ||
199 | timesteps = self.scheduler.timesteps[-init_timestep] | ||
200 | timesteps = torch.tensor([timesteps] * batch_size, device=self.device) | ||
201 | else: | 200 | else: |
202 | timesteps = self.scheduler.timesteps[-init_timestep] | 201 | timesteps = self.scheduler.timesteps[-init_timestep] |
203 | timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) | 202 | timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) |
@@ -273,8 +272,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
273 | if isinstance(self.scheduler, LMSDiscreteScheduler): | 272 | if isinstance(self.scheduler, LMSDiscreteScheduler): |
274 | latents = latents * self.scheduler.sigmas[0] | 273 | latents = latents * self.scheduler.sigmas[0] |
275 | elif isinstance(self.scheduler, EulerAScheduler): | 274 | elif isinstance(self.scheduler, EulerAScheduler): |
276 | sigma = self.scheduler.timesteps[0] | 275 | latents = latents * self.scheduler.sigmas[0] |
277 | latents = latents * sigma | ||
278 | 276 | ||
279 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature | 277 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature |
280 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. | 278 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. |
@@ -301,12 +299,10 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
301 | 299 | ||
302 | noise_pred = None | 300 | noise_pred = None |
303 | if isinstance(self.scheduler, EulerAScheduler): | 301 | if isinstance(self.scheduler, EulerAScheduler): |
304 | sigma = t.reshape(1) | 302 | sigma = self.scheduler.sigmas[t].reshape(1) |
305 | sigma_in = torch.cat([sigma] * latent_model_input.shape[0]) | 303 | sigma_in = torch.cat([sigma] * latent_model_input.shape[0]) |
306 | # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale) | ||
307 | noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, | 304 | noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, |
308 | text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas) | 305 | text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas) |
309 | # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample | ||
310 | else: | 306 | else: |
311 | # predict the noise residual | 307 | # predict the noise residual |
312 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample | 308 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample |
@@ -320,9 +316,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
320 | if isinstance(self.scheduler, LMSDiscreteScheduler): | 316 | if isinstance(self.scheduler, LMSDiscreteScheduler): |
321 | latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample | 317 | latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample |
322 | elif isinstance(self.scheduler, EulerAScheduler): | 318 | elif isinstance(self.scheduler, EulerAScheduler): |
323 | if t_index < self.scheduler.timesteps.shape[0] - 1: # avoid out of bound error | 319 | latents = self.scheduler.step(noise_pred, t_index, t_index + 1, |
324 | t_prev = self.scheduler.timesteps[t_index+1] | 320 | latents, **extra_step_kwargs).prev_sample |
325 | latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample | ||
326 | else: | 321 | else: |
327 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample | 322 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample |
328 | 323 | ||