From c5cc1318c2a7597fe62d3379e50187d0b0f22538 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 1 Dec 2022 13:45:21 +0100
Subject: Update

---
 .../stable_diffusion/vlpn_stable_diffusion.py | 27 +++++++++++----------------
 1 file changed, 11 insertions(+), 16 deletions(-)

(limited to 'pipelines')

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index c77c4d1..9b51763 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -20,7 +20,6 @@ from diffusers import (
     PNDMScheduler,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.models.vae import DecoderOutput
 from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
 from models.clip.prompt import PromptProcessor
@@ -70,7 +69,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler._internal_dict = FrozenDict(new_config)
 
         self.prompt_processor = PromptProcessor(tokenizer, text_encoder)
-        self.use_slicing = False
 
         self.register_modules(
             vae=vae,
@@ -108,9 +106,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
             `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
+
         self.unet.set_attention_slice(slice_size)
 
     def disable_attention_slicing(self):
@@ -144,14 +147,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
         When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
         steps. This is useful to save some memory and allow larger batch sizes.
         """
-        self.use_slicing = True
+        self.vae.enable_slicing()
 
     def disable_vae_slicing(self):
         r"""
         Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
         computing decoding in one step.
         """
-        self.use_slicing = False
+        self.vae.disable_slicing()
 
     @property
     def execution_device(self):
@@ -297,20 +300,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
     def decode_latents(self, latents):
         latents = 1 / 0.18215 * latents
-        image = self.vae_decode(latents).sample
+        image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
-    def vae_decode(self, latents):
-        if self.use_slicing:
-            decoded_slices = [self.vae.decode(latents_slice).sample for latents_slice in latents.split(1)]
-            decoded = torch.cat(decoded_slices)
-            return DecoderOutput(sample=decoded)
-        else:
-            return self.vae.decode(latents)
-
     @torch.no_grad()
     def __call__(
         self,
--
cgit v1.2.3-54-g00ecf
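
Usage note (not part of the patch): after this change the pipeline's custom use_slicing flag and vae_decode() helper are gone, and both memory-saving toggles delegate to the upstream diffusers modules. The sketch below shows how the updated methods would be exercised; the checkpoint path, prompt, and the prompt-first calling convention are assumptions for illustration, and from_pretrained() is the generic loader inherited from DiffusionPipeline.

    # Hypothetical usage sketch for the updated VlpnStableDiffusion pipeline.
    import torch

    from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

    # Placeholder checkpoint path; the actual layout depends on your setup.
    pipeline = VlpnStableDiffusion.from_pretrained(
        "path/to/checkpoint", torch_dtype=torch.float16
    )
    pipeline.to("cuda")

    # "auto" now also works when attention_head_dim is a list: the patch picks
    # the smallest head size instead of failing on `list // 2`.
    pipeline.enable_attention_slicing("auto")

    # VAE slicing is now handled by the AutoencoderKL itself via
    # enable_slicing()/disable_slicing() rather than a pipeline-local flag.
    pipeline.enable_vae_slicing()

    # Assuming the usual Stable Diffusion call convention and a
    # StableDiffusionPipelineOutput-style result with an .images list.
    image = pipeline("a photo of an astronaut riding a horse").images[0]
    image.save("out.png")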