diff options
Diffstat (limited to 'pipelines/stable_diffusion')
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 11 |
1 file changed, 5 insertions, 6 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index bfecd1c..8927a78 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
| @@ -11,7 +11,7 @@ from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscre | |||
| 11 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | 11 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput |
| 12 | from diffusers.utils import logging | 12 | from diffusers.utils import logging |
| 13 | from transformers import CLIPTextModel, CLIPTokenizer | 13 | from transformers import CLIPTextModel, CLIPTokenizer |
| 14 | from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward | 14 | from schedulers.scheduling_euler_a import EulerAScheduler |
| 15 | 15 | ||
| 16 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name | 16 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
| 17 | 17 | ||
| @@ -284,10 +284,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 284 | 284 | ||
| 285 | noise_pred = None | 285 | noise_pred = None |
| 286 | if isinstance(self.scheduler, EulerAScheduler): | 286 | if isinstance(self.scheduler, EulerAScheduler): |
| 287 | sigma = t.reshape(1) | 287 | c_out, c_in, sigma_in = self.scheduler.prepare_input(latent_model_input, t, batch_size) |
| 288 | sigma_in = torch.cat([sigma] * latent_model_input.shape[0]) | 288 | eps = self.unet(latent_model_input * c_in, sigma_in, encoder_hidden_states=text_embeddings).sample |
| 289 | noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, | 289 | noise_pred = latent_model_input + eps * c_out |
| 290 | text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas) | ||
| 291 | else: | 290 | else: |
| 292 | # predict the noise residual | 291 | # predict the noise residual |
| 293 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample | 292 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample |
| @@ -305,7 +304,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 305 | image = self.vae.decode(latents).sample | 304 | image = self.vae.decode(latents).sample |
| 306 | 305 | ||
| 307 | image = (image / 2 + 0.5).clamp(0, 1) | 306 | image = (image / 2 + 0.5).clamp(0, 1) |
| 308 | image = image.cpu().permute(0, 2, 3, 1).numpy() | 307 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() |
| 309 | 308 | ||
| 310 | if output_type == "pil": | 309 | if output_type == "pil": |
| 311 | image = self.numpy_to_pil(image) | 310 | image = self.numpy_to_pil(image) |
