Diffstat (limited to 'pipelines')
-rw-r--r--   pipelines/stable_diffusion/vlpn_stable_diffusion.py   41
1 file changed, 20 insertions, 21 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 98703d5..204276e 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -9,6 +9,7 @@ import torch.nn.functional as F
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.utils import is_accelerate_available
 from diffusers import (
     AutoencoderKL,
@@ -161,6 +162,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler=scheduler,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -443,14 +445,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
-    def decode_latents(self, latents):
-        latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
-        image = (image / 2 + 0.5).clamp(0, 1)
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-        return image
-
     @torch.no_grad()
     def __call__(
         self,
@@ -544,6 +538,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
         do_self_attention_guidance = sag_scale > 0.0
         prep_from_image = isinstance(image, PIL.Image.Image)
+        if not prep_from_image:
+            strength = 1
 
         # 3. Encode input prompt
         prompt_embeds = self.encode_prompt(
@@ -577,7 +573,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
         else:
             latents = self.prepare_latents(
-                batch_size,
+                batch_size * num_images_per_prompt,
                 num_channels_latents,
                 height,
                 width,
@@ -623,9 +619,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (
                     noise_pred_text - noise_pred_uncond
                 )
-                noise_pred = rescale_noise_cfg(
-                    noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
-                )
+                if guidance_rescale > 0.0:
+                    noise_pred = rescale_noise_cfg(
+                        noise_pred,
+                        noise_pred_text,
+                        guidance_rescale=guidance_rescale,
+                    )
 
             if do_self_attention_guidance:
                 # classifier-free guidance produces two chunks of attention map
@@ -690,17 +689,17 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         has_nsfw_concept = None
 
-        if output_type == "latent":
+        if not output_type == "latent":
+            image = self.vae.decode(
+                latents / self.vae.config.scaling_factor, return_dict=False
+            )[0]
+        else:
             image = latents
-        elif output_type == "pil":
-            # 9. Post-processing
-            image = self.decode_latents(latents)
 
-            # 10. Convert to PIL
-            image = self.numpy_to_pil(image)
-        else:
-            # 9. Post-processing
-            image = self.decode_latents(latents)
+        do_denormalize = [True] * image.shape[0]
+        image = self.image_processor.postprocess(
+            image, output_type=output_type, do_denormalize=do_denormalize
+        )
 
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
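
For reference (not part of the commit): the VaeImageProcessor wired into __init__ above takes over everything the deleted decode_latents() helper did, i.e. dividing the latents by vae.config.scaling_factor, decoding, denormalizing to [0, 1], and converting to PIL/NumPy. A minimal standalone sketch of that flow, assuming a stock Stable Diffusion VAE checkpoint and a dummy latents tensor (both placeholders chosen for illustration, not names used in this repository):

import torch
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor

# Placeholder checkpoint; the pipeline in this repo receives its VAE via __init__.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
image_processor = VaeImageProcessor(
    vae_scale_factor=2 ** (len(vae.config.block_out_channels) - 1)
)

with torch.no_grad():
    latents = torch.randn(1, 4, 64, 64)  # stand-in for the denoising loop's output
    # Decode as the updated __call__ does, then let postprocess() handle the
    # (image / 2 + 0.5).clamp(0, 1) denormalization and the PIL conversion that
    # decode_latents() used to perform inline.
    image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
    pil_images = image_processor.postprocess(
        image, output_type="pil", do_denormalize=[True] * image.shape[0]
    )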