| author | Volpeon <git@volpeon.ink> | 2023-06-24 16:26:22 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-06-24 16:26:22 +0200 |
| commit | 27b18776ba6d38d6bda5e5bafee3e7c4ca8c9712 | |
| tree | 6c1f2243475778bb5e9e1725bf3969a5442393d8 /pipelines | |
| parent | Update | |
Fixes
Diffstat (limited to 'pipelines')
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 41 |

1 file changed, 20 insertions(+), 21 deletions(-)
```diff
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 98703d5..204276e 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -9,6 +9,7 @@ import torch.nn.functional as F
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.utils import is_accelerate_available
 from diffusers import (
     AutoencoderKL,
@@ -161,6 +162,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler=scheduler,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
```
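For context, this is roughly how the newly wired-up `VaeImageProcessor` behaves on its own. A minimal sketch, assuming a diffusers version that ships `VaeImageProcessor.postprocess` (circa 0.16+); the random tensor is a stand-in for real `vae.decode()` output:

```python
import torch
from diffusers.image_processor import VaeImageProcessor

# vae_scale_factor matches 2 ** (len(block_out_channels) - 1); 8 is
# typical for Stable Diffusion VAEs with four block_out_channels.
processor = VaeImageProcessor(vae_scale_factor=8)

decoded = torch.rand(1, 3, 512, 512) * 2 - 1  # stand-in for vae.decode() output in [-1, 1]
pil_images = processor.postprocess(
    decoded, output_type="pil", do_denormalize=[True] * decoded.shape[0]
)
print(pil_images[0].size)  # (512, 512)
```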
```diff
@@ -443,14 +445,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
-    def decode_latents(self, latents):
-        latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
-        image = (image / 2 + 0.5).clamp(0, 1)
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-        return image
-
     @torch.no_grad()
     def __call__(
         self,
```
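The removed `decode_latents` helper is not lost functionality: its scaling, `[-1, 1]` to `[0, 1]` denormalization, and NumPy conversion are taken over by `self.image_processor.postprocess(...)` in the last hunk below.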
```diff
@@ -544,6 +538,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
         do_self_attention_guidance = sag_scale > 0.0
         prep_from_image = isinstance(image, PIL.Image.Image)
+        if not prep_from_image:
+            strength = 1
 
         # 3. Encode input prompt
         prompt_embeds = self.encode_prompt(
```
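Forcing `strength = 1` whenever no init image is passed makes the text-to-image path run the full denoising schedule, so a caller-supplied `strength` only shortens the schedule when there is actually a `PIL.Image.Image` to preserve.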
```diff
@@ -577,7 +573,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
         else:
             latents = self.prepare_latents(
-                batch_size,
+                batch_size * num_images_per_prompt,
                 num_channels_latents,
                 height,
                 width,
```
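The `batch_size * num_images_per_prompt` fix sizes the freshly sampled latents to match the prompt embeddings, which are duplicated once per requested image; with plain `batch_size`, a text-to-image call with `num_images_per_prompt > 1` would hit a batch-dimension mismatch in the UNet.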
```diff
@@ -623,9 +619,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (
                     noise_pred_text - noise_pred_uncond
                 )
-                noise_pred = rescale_noise_cfg(
-                    noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
-                )
+                if guidance_rescale > 0.0:
+                    noise_pred = rescale_noise_cfg(
+                        noise_pred,
+                        noise_pred_text,
+                        guidance_rescale=guidance_rescale,
+                    )
 
             if do_self_attention_guidance:
                 # classifier-free guidance produces two chunks of attention map
```
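The new guard skips a no-op: with `guidance_rescale == 0` the final blend returns the guided prediction unchanged, so calling the function only wastes two std reductions. A sketch of `rescale_noise_cfg` as it appears in diffusers' reference Stable Diffusion pipeline (this repo's copy may differ slightly); it implements the overexposure fix from "Common Diffusion Noise Schedules and Sample Steps are Flawed" (Lin et al., 2023):

```python
import torch

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # per-sample std over all non-batch dimensions
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the guided prediction to the text prediction's scale (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # blend back toward the unrescaled prediction; guidance_rescale == 0 is an identity
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
```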
```diff
@@ -690,17 +689,17 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         has_nsfw_concept = None
 
-        if output_type == "latent":
+        if not output_type == "latent":
+            image = self.vae.decode(
+                latents / self.vae.config.scaling_factor, return_dict=False
+            )[0]
+        else:
             image = latents
-        elif output_type == "pil":
-            # 9. Post-processing
-            image = self.decode_latents(latents)
 
-            # 10. Convert to PIL
-            image = self.numpy_to_pil(image)
-        else:
-            # 9. Post-processing
-            image = self.decode_latents(latents)
+        do_denormalize = [True] * image.shape[0]
+        image = self.image_processor.postprocess(
+            image, output_type=output_type, do_denormalize=do_denormalize
+        )
 
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
```
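To see the reworked output path in isolation, a small sketch with placeholder tensors (the real pipeline uses its own `vae` and `image_processor` attributes). It relies on `postprocess` returning `"latent"` inputs unchanged, which is why the unconditional call after the branch is safe for every `output_type`:

```python
import torch
from diffusers.image_processor import VaeImageProcessor

image_processor = VaeImageProcessor(vae_scale_factor=8)

for output_type in ("latent", "pt", "np", "pil"):
    if not output_type == "latent":
        image = torch.rand(1, 3, 64, 64) * 2 - 1  # stand-in for vae.decode() output
    else:
        image = torch.randn(1, 4, 8, 8)  # raw latents are passed through untouched
    do_denormalize = [True] * image.shape[0]
    image = image_processor.postprocess(
        image, output_type=output_type, do_denormalize=do_denormalize
    )
    print(output_type, type(image))
```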
