summaryrefslogtreecommitdiffstats
path: root/pipelines
diff options
context:
space:
mode:
Diffstat (limited to 'pipelines')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py11
1 files changed, 5 insertions, 6 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index bfecd1c..8927a78 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -11,7 +11,7 @@ from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscre
11from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput 11from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
12from diffusers.utils import logging 12from diffusers.utils import logging
13from transformers import CLIPTextModel, CLIPTokenizer 13from transformers import CLIPTextModel, CLIPTokenizer
14from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward 14from schedulers.scheduling_euler_a import EulerAScheduler
15 15
16logger = logging.get_logger(__name__) # pylint: disable=invalid-name 16logger = logging.get_logger(__name__) # pylint: disable=invalid-name
17 17
@@ -284,10 +284,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
284 284
285 noise_pred = None 285 noise_pred = None
286 if isinstance(self.scheduler, EulerAScheduler): 286 if isinstance(self.scheduler, EulerAScheduler):
287 sigma = t.reshape(1) 287 c_out, c_in, sigma_in = self.scheduler.prepare_input(latent_model_input, t, batch_size)
288 sigma_in = torch.cat([sigma] * latent_model_input.shape[0]) 288 eps = self.unet(latent_model_input * c_in, sigma_in, encoder_hidden_states=text_embeddings).sample
289 noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, 289 noise_pred = latent_model_input + eps * c_out
290 text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas)
291 else: 290 else:
292 # predict the noise residual 291 # predict the noise residual
293 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample 292 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -305,7 +304,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
305 image = self.vae.decode(latents).sample 304 image = self.vae.decode(latents).sample
306 305
307 image = (image / 2 + 0.5).clamp(0, 1) 306 image = (image / 2 + 0.5).clamp(0, 1)
308 image = image.cpu().permute(0, 2, 3, 1).numpy() 307 image = image.cpu().permute(0, 2, 3, 1).float().numpy()
309 308
310 if output_type == "pil": 309 if output_type == "pil":
311 image = self.numpy_to_pil(image) 310 image = self.numpy_to_pil(image)