Diffstat (limited to 'pipelines')
-rw-r--r--    pipelines/stable_diffusion/vlpn_stable_diffusion.py    27
1 file changed, 11 insertions, 16 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index c77c4d1..9b51763 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -20,7 +20,6 @@ from diffusers import (
     PNDMScheduler,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.models.vae import DecoderOutput
 from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
 from models.clip.prompt import PromptProcessor
@@ -70,7 +69,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
         scheduler._internal_dict = FrozenDict(new_config)
 
         self.prompt_processor = PromptProcessor(tokenizer, text_encoder)
-        self.use_slicing = False
 
         self.register_modules(
             vae=vae,
@@ -108,9 +106,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
             `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
+
         self.unet.set_attention_slice(slice_size)
 
     def disable_attention_slicing(self):
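
Note: the "auto" branch above now handles `attention_head_dim` whether it is a single int or a per-block list. A rough standalone sketch of that selection logic (the helper name is hypothetical, not part of the pipeline):

def resolve_auto_slice_size(attention_head_dim):
    # int config: half the head size is the usual speed/memory trade-off
    if isinstance(attention_head_dim, int):
        return attention_head_dim // 2
    # list config (one value per block): take the smallest head size
    return min(attention_head_dim)
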
@@ -144,14 +147,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
         When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
         steps. This is useful to save some memory and allow larger batch sizes.
         """
-        self.use_slicing = True
+        self.vae.enable_slicing()
 
     def disable_vae_slicing(self):
         r"""
         Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
         computing decoding in one step.
         """
-        self.use_slicing = False
+        self.vae.disable_slicing()
 
     @property
     def execution_device(self):
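
The pipeline now defers to the VAE's own slicing switches instead of tracking a `use_slicing` flag itself; `enable_slicing()` / `disable_slicing()` are the corresponding methods on diffusers' AutoencoderKL. A minimal usage sketch, assuming a recent diffusers release and a placeholder checkpoint path:

from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("path/to/vae")  # placeholder checkpoint
vae.enable_slicing()    # decode() now splits the batch and decodes slice by slice
# ... decode latents / run the pipeline ...
vae.disable_slicing()   # restore single-pass decoding
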
@@ -297,20 +300,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
     def decode_latents(self, latents):
         latents = 1 / 0.18215 * latents
-        image = self.vae_decode(latents).sample
+        image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
-    def vae_decode(self, latents):
-        if self.use_slicing:
-            decoded_slices = [self.vae.decode(latents_slice).sample for latents_slice in latents.split(1)]
-            decoded = torch.cat(decoded_slices)
-            return DecoderOutput(sample=decoded)
-        else:
-            return self.vae.decode(latents)
-
     @torch.no_grad()
     def __call__(
         self,
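
For reference, the removed `vae_decode` helper decoded the batch one latent at a time and concatenated the results; with the change above, the same behaviour is expected from the VAE itself once `enable_slicing()` has been called. A hedged sketch of the equivalent slice-wise decode (the helper name is ours, not from the codebase):

import torch

def sliced_decode(vae, latents):
    # decode one latent at a time and concatenate, mirroring the removed helper
    decoded = [vae.decode(latent_slice).sample for latent_slice in latents.split(1)]
    return torch.cat(decoded)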