From f5b656d21c5b449eed6ce212e909043c124f79ee Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 12 Oct 2022 08:18:22 +0200 Subject: Various updates --- pipelines/stable_diffusion/vlpn_stable_diffusion.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) (limited to 'pipelines/stable_diffusion') diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index bfecd1c..8927a78 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -11,7 +11,7 @@ from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscre from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput from diffusers.utils import logging from transformers import CLIPTextModel, CLIPTokenizer -from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward +from schedulers.scheduling_euler_a import EulerAScheduler logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -284,10 +284,9 @@ class VlpnStableDiffusion(DiffusionPipeline): noise_pred = None if isinstance(self.scheduler, EulerAScheduler): - sigma = t.reshape(1) - sigma_in = torch.cat([sigma] * latent_model_input.shape[0]) - noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, - text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas) + c_out, c_in, sigma_in = self.scheduler.prepare_input(latent_model_input, t, batch_size) + eps = self.unet(latent_model_input * c_in, sigma_in, encoder_hidden_states=text_embeddings).sample + noise_pred = latent_model_input + eps * c_out else: # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -305,7 +304,7 @@ class VlpnStableDiffusion(DiffusionPipeline): image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() + image = image.cpu().permute(0, 2, 3, 1).float().numpy() if output_type == "pil": image = self.numpy_to_pil(image) -- cgit v1.2.3-70-g09d2