From 329ad48b307e782b0e23fce80ae9087a4003e73d Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Wed, 30 Nov 2022 14:02:35 +0100
Subject: Add sliced VAE decoding to VlpnStableDiffusion

---
 .../stable_diffusion/vlpn_stable_diffusion.py      | 28 +++++++++++++++++++++-
 1 file changed, 27 insertions(+), 1 deletion(-)

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 85b0216..c77c4d1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -20,6 +20,7 @@ from diffusers import (
     PNDMScheduler,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.models.vae import DecoderOutput
 from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
 from models.clip.prompt import PromptProcessor
@@ -69,6 +70,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler._internal_dict = FrozenDict(new_config)
 
         self.prompt_processor = PromptProcessor(tokenizer, text_encoder)
+        self.use_slicing = False
 
         self.register_modules(
             vae=vae,
@@ -136,6 +138,21 @@ class VlpnStableDiffusion(DiffusionPipeline):
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding.
+        When this option is enabled, the VAE splits the input tensor into slices and computes decoding in several
+        steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.use_slicing = False
+
     @property
     def execution_device(self):
         r"""
@@ -280,12 +297,21 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
     def decode_latents(self, latents):
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae_decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
+    def vae_decode(self, latents):
+        if self.use_slicing:
+            # Decode one latent at a time so peak memory scales with a single sample rather than the batch.
+            decoded_slices = [self.vae.decode(latents_slice).sample for latents_slice in latents.split(1)]
+            decoded = torch.cat(decoded_slices)
+            return DecoderOutput(sample=decoded)
+        else:
+            return self.vae.decode(latents)
+
     @torch.no_grad()
     def __call__(
         self,
-- 
cgit v1.2.3-70-g09d2
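
Note on the sliced path in vae_decode: latents.split(1) yields one
single-sample tensor per batch element along dim 0, each slice is decoded on
its own, and torch.cat stitches the results back together. The output is
identical to a full-batch decode; only the peak activation memory changes. A
minimal, self-contained sketch of that round trip (plain torch, with an
identity function standing in for the real VAE decoder):

    import torch

    # Stand-in latent batch: 4 samples in the usual SD latent layout
    # (batch, channels, height / 8, width / 8). Values are random; only
    # the shapes matter for this demonstration.
    latents = torch.randn(4, 4, 64, 64)

    # Placeholder for `self.vae.decode(x).sample`; an identity function
    # keeps the sketch runnable without loading a real VAE.
    decode = lambda x: x

    # What vae_decode does with slicing enabled: decode each sample
    # alone, then reassemble the batch along dim 0.
    decoded_slices = [decode(s) for s in latents.split(1)]
    decoded = torch.cat(decoded_slices)  # back to shape (4, 4, 64, 64)

    # Same result as decoding the whole batch in one call.
    assert torch.equal(decoded, decode(latents))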
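
A hedged usage sketch for the new switches. The pipeline's __init__ and
__call__ signatures are truncated in this patch, so apart from
enable_vae_slicing/disable_vae_slicing everything below is an assumption
about the surrounding code; `pipe` is taken to be an already-constructed
VlpnStableDiffusion:

    # Assumption: `pipe` is a VlpnStableDiffusion instance already moved
    # to the GPU; prompt and batch size are arbitrary illustration values.
    pipe.enable_vae_slicing()  # decode latents one sample at a time

    # A plain prompt-list call is assumed here; __call__ is decorated with
    # @torch.no_grad(), so no extra context manager is needed.
    output = pipe(prompt=["a lighthouse at dusk"] * 8)
    images = output.images  # StableDiffusionPipelineOutput.images

    pipe.disable_vae_slicing()  # back to single-pass decoding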