From 13b0d9f763269df405d1aeba86213f1c7ce4e7ca Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 2 Oct 2022 15:14:29 +0200 Subject: More consistent euler_a --- pipelines/stable_diffusion/vlpn_stable_diffusion.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) (limited to 'pipelines/stable_diffusion/vlpn_stable_diffusion.py') diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 4c793a8..a8ecedf 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -185,6 +185,8 @@ class VlpnStableDiffusion(DiffusionPipeline): latent_dist = self.vae.encode(latents.to(self.device)).latent_dist latents = latent_dist.sample(generator=generator) latents = 0.18215 * latents + + # expand init_latents for batch_size latents = torch.cat([latents] * batch_size) # get the original timestep using init_timestep @@ -195,9 +197,6 @@ class VlpnStableDiffusion(DiffusionPipeline): timesteps = torch.tensor( [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device ) - elif isinstance(self.scheduler, EulerAScheduler): - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size, device=self.device) else: timesteps = self.scheduler.timesteps[-init_timestep] timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) @@ -273,8 +272,7 @@ class VlpnStableDiffusion(DiffusionPipeline): if isinstance(self.scheduler, LMSDiscreteScheduler): latents = latents * self.scheduler.sigmas[0] elif isinstance(self.scheduler, EulerAScheduler): - sigma = self.scheduler.timesteps[0] - latents = latents * sigma + latents = latents * self.scheduler.sigmas[0] # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -301,12 +299,10 @@ class VlpnStableDiffusion(DiffusionPipeline): noise_pred = None if isinstance(self.scheduler, EulerAScheduler): - sigma = t.reshape(1) + sigma = self.scheduler.sigmas[t].reshape(1) sigma_in = torch.cat([sigma] * latent_model_input.shape[0]) - # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale) noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas) - # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample else: # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -320,9 +316,8 @@ class VlpnStableDiffusion(DiffusionPipeline): if isinstance(self.scheduler, LMSDiscreteScheduler): latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample elif isinstance(self.scheduler, EulerAScheduler): - if t_index < self.scheduler.timesteps.shape[0] - 1: # avoid out of bound error - t_prev = self.scheduler.timesteps[t_index+1] - latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t_index, t_index + 1, + latents, **extra_step_kwargs).prev_sample else: latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample -- cgit v1.2.3-54-g00ecf