From 49463992f48ec25f2ea31b220a6cedac3466467a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 26 Oct 2022 11:11:33 +0200
Subject: New Euler_a scheduler

---
 .../stable_diffusion/vlpn_stable_diffusion.py | 27 +++++++++++-----------
 1 file changed, 13 insertions(+), 14 deletions(-)

(limited to 'pipelines/stable_diffusion')

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index e90528d..fc12355 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -11,7 +11,7 @@ from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscre
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
-from schedulers.scheduling_euler_a import EulerAScheduler
+from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from models.clip.prompt import PromptProcessor
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -32,7 +32,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAScheduler],
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler],
         **kwargs,
     ):
         super().__init__()
@@ -225,8 +225,13 @@ class VlpnStableDiffusion(DiffusionPipeline):
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)
 
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+        if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler):
+            timesteps = torch.tensor(
+                [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
+            )
+        else:
+            timesteps = self.scheduler.timesteps[-init_timestep]
+            timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
 
         # add noise to latents using the timesteps
         noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
@@ -259,16 +264,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t, i)
 
-            noise_pred = None
-            if isinstance(self.scheduler, EulerAScheduler):
-                c_out, c_in, sigma_in = self.scheduler.prepare_input(latent_model_input, t, batch_size)
-                eps = self.unet(latent_model_input * c_in, sigma_in, encoder_hidden_states=text_embeddings).sample
-                noise_pred = latent_model_input + eps * c_out
-            else:
-                # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+            # predict the noise residual
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
 
             # perform guidance
             if do_classifier_free_guidance:
                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, i, latents, **extra_step_kwargs).prev_sample
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
--
cgit v1.2.3-54-g00ecf
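
Usage note (not part of the patch): with this change the sampling loop no longer special-cases the scheduler; scale_model_input() and step() now carry the scheduler-specific logic, so switching schedulers reduces to a constructor argument. Below is a minimal sketch of wiring the pipeline with the new scheduler. It assumes the repo's vendored schedulers/scheduling_euler_ancestral_discrete.py, whose scale_model_input() and step() accept the loop index i as an extra argument (the upstream diffusers class of the same name does not), and the checkpoint id is likewise an assumption.

    # Minimal sketch, assuming the vendored EulerAncestralDiscreteScheduler
    # whose scale_model_input()/step() take the loop index, as this patch expects.
    from diffusers import AutoencoderKL, UNet2DConditionModel
    from transformers import CLIPTextModel, CLIPTokenizer

    from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
    from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler

    model_id = "CompVis/stable-diffusion-v1-4"  # assumed checkpoint

    pipe = VlpnStableDiffusion(
        vae=AutoencoderKL.from_pretrained(model_id, subfolder="vae"),
        text_encoder=CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder"),
        tokenizer=CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer"),
        unet=UNet2DConditionModel.from_pretrained(model_id, subfolder="unet"),
        # beta schedule matching Stable Diffusion's training configuration
        scheduler=EulerAncestralDiscreteScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
        ),
    ).to("cuda")

    image = pipe("a photo of an astronaut riding a horse").images[0]
    image.save("astronaut.png")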