From 2ad46871e2ead985445da2848a4eb7072b6e48aa Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 14 Nov 2022 17:09:58 +0100 Subject: Update --- .../stable_diffusion/vlpn_stable_diffusion.py | 33 ++++++++++++++-------- 1 file changed, 22 insertions(+), 11 deletions(-) (limited to 'pipelines/stable_diffusion/vlpn_stable_diffusion.py') diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 36942f0..ba057ba 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -8,11 +8,20 @@ import PIL from diffusers.configuration_utils import FrozenDict from diffusers.utils import is_accelerate_available -from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel +from diffusers import ( + AutoencoderKL, + DiffusionPipeline, + UNet2DConditionModel, + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput from diffusers.utils import logging from transformers import CLIPTextModel, CLIPTokenizer -from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from models.clip.prompt import PromptProcessor logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -33,7 +42,14 @@ class VlpnStableDiffusion(DiffusionPipeline): text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler], + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], **kwargs, ): super().__init__() @@ -252,19 +268,14 @@ class VlpnStableDiffusion(DiffusionPipeline): latents = 0.18215 * latents # expand init_latents for batch_size - latents = torch.cat([latents] * batch_size) + latents = torch.cat([latents] * batch_size, dim=0) # get the original timestep using init_timestep init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) - if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler): - timesteps = torch.tensor( - [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device - ) - else: - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size, device=self.device) + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size, device=self.device) # add noise to latents using the timesteps noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype) -- cgit v1.2.3-54-g00ecf