Diffstat (limited to 'pipelines/stable_diffusion/vlpn_stable_diffusion.py')
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  41
1 file changed, 20 insertions(+), 21 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 98703d5..204276e 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -9,6 +9,7 @@ import torch.nn.functional as F
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.utils import is_accelerate_available
 from diffusers import (
     AutoencoderKL,
@@ -161,6 +162,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler=scheduler,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -443,14 +445,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
-    def decode_latents(self, latents):
-        latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
-        image = (image / 2 + 0.5).clamp(0, 1)
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-        return image
-
     @torch.no_grad()
     def __call__(
         self,
@@ -544,6 +538,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
         do_self_attention_guidance = sag_scale > 0.0
         prep_from_image = isinstance(image, PIL.Image.Image)
+        if not prep_from_image:
+            strength = 1
 
         # 3. Encode input prompt
         prompt_embeds = self.encode_prompt(
@@ -577,7 +573,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
         else:
             latents = self.prepare_latents(
-                batch_size,
+                batch_size * num_images_per_prompt,
                 num_channels_latents,
                 height,
                 width,
@@ -623,9 +619,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + guidance_scale * (
                         noise_pred_text - noise_pred_uncond
                     )
-                    noise_pred = rescale_noise_cfg(
-                        noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
-                    )
+                    if guidance_rescale > 0.0:
+                        noise_pred = rescale_noise_cfg(
+                            noise_pred,
+                            noise_pred_text,
+                            guidance_rescale=guidance_rescale,
+                        )
 
                 if do_self_attention_guidance:
                     # classifier-free guidance produces two chunks of attention map
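
Guarding rescale_noise_cfg behind guidance_rescale > 0.0 only skips work that would be a no-op: with guidance_rescale == 0 the blend returns the CFG prediction unchanged. As a reference, a sketch of the rescaling that diffusers' rescale_noise_cfg performs (from "Common Diffusion Noise Schedules and Sample Steps are Flawed"):

    import torch

    def rescale_noise_cfg_sketch(noise_cfg, noise_pred_text, guidance_rescale=0.0):
        # Match the per-sample std of the CFG result to the std of the
        # text-conditioned prediction, then blend by guidance_rescale.
        std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
        std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
        rescaled = noise_cfg * (std_text / std_cfg)
        return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg
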
@@ -690,17 +689,17 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         has_nsfw_concept = None
 
-        if output_type == "latent":
+        if not output_type == "latent":
+            image = self.vae.decode(
+                latents / self.vae.config.scaling_factor, return_dict=False
+            )[0]
+        else:
             image = latents
-        elif output_type == "pil":
-            # 9. Post-processing
-            image = self.decode_latents(latents)
 
-            # 10. Convert to PIL
-            image = self.numpy_to_pil(image)
-        else:
-            # 9. Post-processing
-            image = self.decode_latents(latents)
+        do_denormalize = [True] * image.shape[0]
+        image = self.image_processor.postprocess(
+            image, output_type=output_type, do_denormalize=do_denormalize
+        )
 
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
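
Taken together, the new tail of __call__ follows the decode-then-postprocess idiom of current diffusers pipelines. A self-contained sketch of that path outside the pipeline; the checkpoint name is only an example:

    import torch
    from diffusers import AutoencoderKL
    from diffusers.image_processor import VaeImageProcessor

    vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
    image_processor = VaeImageProcessor(vae_scale_factor=2 ** (len(vae.config.block_out_channels) - 1))

    latents = torch.randn(1, 4, 64, 64)  # stand-in for denoised latents
    with torch.no_grad():
        image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
    images = image_processor.postprocess(
        image, output_type="pil", do_denormalize=[True] * image.shape[0]
    )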