From 329ad48b307e782b0e23fce80ae9087a4003e73d Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 30 Nov 2022 14:02:35 +0100 Subject: Update --- .../stable_diffusion/vlpn_stable_diffusion.py | 27 +++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) (limited to 'pipelines') diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 85b0216..c77c4d1 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -20,6 +20,7 @@ from diffusers import ( PNDMScheduler, ) from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput +from diffusers.models.vae import DecoderOutput from diffusers.utils import logging from transformers import CLIPTextModel, CLIPTokenizer from models.clip.prompt import PromptProcessor @@ -69,6 +70,7 @@ class VlpnStableDiffusion(DiffusionPipeline): scheduler._internal_dict = FrozenDict(new_config) self.prompt_processor = PromptProcessor(tokenizer, text_encoder) + self.use_slicing = False self.register_modules( vae=vae, @@ -136,6 +138,21 @@ class VlpnStableDiffusion(DiffusionPipeline): if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.use_slicing = False + @property def execution_device(self): r""" @@ -280,12 +297,20 @@ class VlpnStableDiffusion(DiffusionPipeline): def decode_latents(self, latents): latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample + image = self.vae_decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image + def vae_decode(self, latents): + if self.use_slicing: + decoded_slices = [self.vae.decode(latents_slice).sample for latents_slice in latents.split(1)] + decoded = torch.cat(decoded_slices) + return DecoderOutput(sample=decoded) + else: + return self.vae.decode(latents) + @torch.no_grad() def __call__( self, -- cgit v1.2.3-54-g00ecf