diff options
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index bfecd1c..8927a78 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -11,7 +11,7 @@ from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscre | |||
11 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | 11 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput |
12 | from diffusers.utils import logging | 12 | from diffusers.utils import logging |
13 | from transformers import CLIPTextModel, CLIPTokenizer | 13 | from transformers import CLIPTextModel, CLIPTokenizer |
14 | from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward | 14 | from schedulers.scheduling_euler_a import EulerAScheduler |
15 | 15 | ||
16 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name | 16 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
17 | 17 | ||
@@ -284,10 +284,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
284 | 284 | ||
285 | noise_pred = None | 285 | noise_pred = None |
286 | if isinstance(self.scheduler, EulerAScheduler): | 286 | if isinstance(self.scheduler, EulerAScheduler): |
287 | sigma = t.reshape(1) | 287 | c_out, c_in, sigma_in = self.scheduler.prepare_input(latent_model_input, t, batch_size) |
288 | sigma_in = torch.cat([sigma] * latent_model_input.shape[0]) | 288 | eps = self.unet(latent_model_input * c_in, sigma_in, encoder_hidden_states=text_embeddings).sample |
289 | noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, | 289 | noise_pred = latent_model_input + eps * c_out |
290 | text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas) | ||
291 | else: | 290 | else: |
292 | # predict the noise residual | 291 | # predict the noise residual |
293 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample | 292 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample |
@@ -305,7 +304,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
305 | image = self.vae.decode(latents).sample | 304 | image = self.vae.decode(latents).sample |
306 | 305 | ||
307 | image = (image / 2 + 0.5).clamp(0, 1) | 306 | image = (image / 2 + 0.5).clamp(0, 1) |
308 | image = image.cpu().permute(0, 2, 3, 1).numpy() | 307 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() |
309 | 308 | ||
310 | if output_type == "pil": | 309 | if output_type == "pil": |
311 | image = self.numpy_to_pil(image) | 310 | image = self.numpy_to_pil(image) |