From 49463992f48ec25f2ea31b220a6cedac3466467a Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Wed, 26 Oct 2022 11:11:33 +0200
Subject: New Euler_a scheduler

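Replace the custom EulerAScheduler with a port of the Euler ancestral
discrete scheduler. The new class exposes a scale_model_input()/step()
interface, so the denoising loop no longer needs the special-cased
c_in/c_out epsilon-prediction branch; the loop index is passed alongside
the timestep so the scheduler can look up the matching sigma. During
img2img initialization, schedulers other than DDIM and PNDM now receive
a step index (num_inference_steps - init_timestep) instead of a value
taken from scheduler.timesteps.

Minimal wiring sketch (assumptions: the beta-schedule keywords and the
vae parameter name are not shown in this patch; they mirror the usual
diffusers-style constructor):

    # vae, text_encoder, tokenizer and unet are loaded elsewhere;
    # the scheduler keywords below are assumptions, not part of this patch.
    from schedulers.scheduling_euler_ancestral_discrete import (
        EulerAncestralDiscreteScheduler,
    )

    scheduler = EulerAncestralDiscreteScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
    )
    pipe = VlpnStableDiffusion(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
    )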
---
 .../stable_diffusion/vlpn_stable_diffusion.py      | 27 +++++++++++-----------
 1 file changed, 13 insertions(+), 14 deletions(-)


diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index e90528d..fc12355 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -11,7 +11,7 @@ from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscre
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
-from schedulers.scheduling_euler_a import EulerAScheduler
+from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from models.clip.prompt import PromptProcessor
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -32,7 +32,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAScheduler],
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler],
         **kwargs,
     ):
         super().__init__()
@@ -225,8 +225,13 @@ class VlpnStableDiffusion(DiffusionPipeline):
             init_timestep = int(num_inference_steps * strength) + offset
             init_timestep = min(init_timestep, num_inference_steps)
 
-            timesteps = self.scheduler.timesteps[-init_timestep]
-            timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+            if not isinstance(self.scheduler, (DDIMScheduler, PNDMScheduler)):
+                timesteps = torch.tensor(
+                    [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
+                )
+            else:
+                timesteps = self.scheduler.timesteps[-init_timestep]
+                timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
 
             # add noise to latents using the timesteps
             noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
@@ -259,16 +264,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t, i)
 
-            noise_pred = None
-            if isinstance(self.scheduler, EulerAScheduler):
-                c_out, c_in, sigma_in = self.scheduler.prepare_input(latent_model_input, t, batch_size)
-                eps = self.unet(latent_model_input * c_in, sigma_in, encoder_hidden_states=text_embeddings).sample
-                noise_pred = latent_model_input + eps * c_out
-            else:
-                # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+            # predict the noise residual
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
 
             # perform guidance
             if do_classifier_free_guidance:
@@ -276,7 +275,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, i, latents, **extra_step_kwargs).prev_sample
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents