diff options
Diffstat (limited to 'pipelines')
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 27 |
1 files changed, 11 insertions, 16 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index c77c4d1..9b51763 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
| @@ -20,7 +20,6 @@ from diffusers import ( | |||
| 20 | PNDMScheduler, | 20 | PNDMScheduler, |
| 21 | ) | 21 | ) |
| 22 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | 22 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput |
| 23 | from diffusers.models.vae import DecoderOutput | ||
| 24 | from diffusers.utils import logging | 23 | from diffusers.utils import logging |
| 25 | from transformers import CLIPTextModel, CLIPTokenizer | 24 | from transformers import CLIPTextModel, CLIPTokenizer |
| 26 | from models.clip.prompt import PromptProcessor | 25 | from models.clip.prompt import PromptProcessor |
| @@ -70,7 +69,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 70 | scheduler._internal_dict = FrozenDict(new_config) | 69 | scheduler._internal_dict = FrozenDict(new_config) |
| 71 | 70 | ||
| 72 | self.prompt_processor = PromptProcessor(tokenizer, text_encoder) | 71 | self.prompt_processor = PromptProcessor(tokenizer, text_encoder) |
| 73 | self.use_slicing = False | ||
| 74 | 72 | ||
| 75 | self.register_modules( | 73 | self.register_modules( |
| 76 | vae=vae, | 74 | vae=vae, |
| @@ -108,9 +106,14 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 108 | `attention_head_dim` must be a multiple of `slice_size`. | 106 | `attention_head_dim` must be a multiple of `slice_size`. |
| 109 | """ | 107 | """ |
| 110 | if slice_size == "auto": | 108 | if slice_size == "auto": |
| 111 | # half the attention head size is usually a good trade-off between | 109 | if isinstance(self.unet.config.attention_head_dim, int): |
| 112 | # speed and memory | 110 | # half the attention head size is usually a good trade-off between |
| 113 | slice_size = self.unet.config.attention_head_dim // 2 | 111 | # speed and memory |
| 112 | slice_size = self.unet.config.attention_head_dim // 2 | ||
| 113 | else: | ||
| 114 | # if `attention_head_dim` is a list, take the smallest head size | ||
| 115 | slice_size = min(self.unet.config.attention_head_dim) | ||
| 116 | |||
| 114 | self.unet.set_attention_slice(slice_size) | 117 | self.unet.set_attention_slice(slice_size) |
| 115 | 118 | ||
| 116 | def disable_attention_slicing(self): | 119 | def disable_attention_slicing(self): |
| @@ -144,14 +147,14 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 144 | When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several | 147 | When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several |
| 145 | steps. This is useful to save some memory and allow larger batch sizes. | 148 | steps. This is useful to save some memory and allow larger batch sizes. |
| 146 | """ | 149 | """ |
| 147 | self.use_slicing = True | 150 | self.vae.enable_slicing() |
| 148 | 151 | ||
| 149 | def disable_vae_slicing(self): | 152 | def disable_vae_slicing(self): |
| 150 | r""" | 153 | r""" |
| 151 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to | 154 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to |
| 152 | computing decoding in one step. | 155 | computing decoding in one step. |
| 153 | """ | 156 | """ |
| 154 | self.use_slicing = False | 157 | self.vae.disable_slicing() |
| 155 | 158 | ||
| 156 | @property | 159 | @property |
| 157 | def execution_device(self): | 160 | def execution_device(self): |
| @@ -297,20 +300,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 297 | 300 | ||
| 298 | def decode_latents(self, latents): | 301 | def decode_latents(self, latents): |
| 299 | latents = 1 / 0.18215 * latents | 302 | latents = 1 / 0.18215 * latents |
| 300 | image = self.vae_decode(latents).sample | 303 | image = self.vae.decode(latents).sample |
| 301 | image = (image / 2 + 0.5).clamp(0, 1) | 304 | image = (image / 2 + 0.5).clamp(0, 1) |
| 302 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 | 305 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 |
| 303 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() | 306 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() |
| 304 | return image | 307 | return image |
| 305 | 308 | ||
| 306 | def vae_decode(self, latents): | ||
| 307 | if self.use_slicing: | ||
| 308 | decoded_slices = [self.vae.decode(latents_slice).sample for latents_slice in latents.split(1)] | ||
| 309 | decoded = torch.cat(decoded_slices) | ||
| 310 | return DecoderOutput(sample=decoded) | ||
| 311 | else: | ||
| 312 | return self.vae.decode(latents) | ||
| 313 | |||
| 314 | @torch.no_grad() | 309 | @torch.no_grad() |
| 315 | def __call__( | 310 | def __call__( |
| 316 | self, | 311 | self, |
