summaryrefslogtreecommitdiffstats
path: root/pipelines/stable_diffusion
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-11-30 14:02:35 +0100
committerVolpeon <git@volpeon.ink>2022-11-30 14:02:35 +0100
commit329ad48b307e782b0e23fce80ae9087a4003e73d (patch)
tree0c72434a8d45ae933582064849b43bd7419f7ee8 /pipelines/stable_diffusion
parentAdjusted training to upstream (diff)
downloadtextual-inversion-diff-329ad48b307e782b0e23fce80ae9087a4003e73d.tar.gz
textual-inversion-diff-329ad48b307e782b0e23fce80ae9087a4003e73d.tar.bz2
textual-inversion-diff-329ad48b307e782b0e23fce80ae9087a4003e73d.zip
Update
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py27
1 files changed, 26 insertions, 1 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 85b0216..c77c4d1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -20,6 +20,7 @@ from diffusers import (
20 PNDMScheduler, 20 PNDMScheduler,
21) 21)
22from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput 22from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
23from diffusers.models.vae import DecoderOutput
23from diffusers.utils import logging 24from diffusers.utils import logging
24from transformers import CLIPTextModel, CLIPTokenizer 25from transformers import CLIPTextModel, CLIPTokenizer
25from models.clip.prompt import PromptProcessor 26from models.clip.prompt import PromptProcessor
@@ -69,6 +70,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
69 scheduler._internal_dict = FrozenDict(new_config) 70 scheduler._internal_dict = FrozenDict(new_config)
70 71
71 self.prompt_processor = PromptProcessor(tokenizer, text_encoder) 72 self.prompt_processor = PromptProcessor(tokenizer, text_encoder)
73 self.use_slicing = False
72 74
73 self.register_modules( 75 self.register_modules(
74 vae=vae, 76 vae=vae,
@@ -136,6 +138,21 @@ class VlpnStableDiffusion(DiffusionPipeline):
136 if cpu_offloaded_model is not None: 138 if cpu_offloaded_model is not None:
137 cpu_offload(cpu_offloaded_model, device) 139 cpu_offload(cpu_offloaded_model, device)
138 140
def enable_vae_slicing(self):
    r"""
    Switch on sliced VAE decoding.

    Sets the `use_slicing` flag. While the flag is set, `vae_decode` splits the latents along the first
    dimension into chunks of size one, decodes each chunk separately and concatenates the results, instead
    of decoding everything in a single call. This is intended to lower peak memory use and thereby allow
    larger batch sizes.
    """
    self.use_slicing = True
148
def disable_vae_slicing(self):
    r"""
    Switch off sliced VAE decoding.

    Clears the `use_slicing` flag, so that `vae_decode` goes back to decoding all latents in one step.
    This undoes a previous call to `enable_vae_slicing`.
    """
    self.use_slicing = False
155
139 @property 156 @property
140 def execution_device(self): 157 def execution_device(self):
141 r""" 158 r"""
@@ -280,12 +297,20 @@ class VlpnStableDiffusion(DiffusionPipeline):
280 297
281 def decode_latents(self, latents): 298 def decode_latents(self, latents):
282 latents = 1 / 0.18215 * latents 299 latents = 1 / 0.18215 * latents
283 image = self.vae.decode(latents).sample 300 image = self.vae_decode(latents).sample
284 image = (image / 2 + 0.5).clamp(0, 1) 301 image = (image / 2 + 0.5).clamp(0, 1)
285 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 302 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
286 image = image.cpu().permute(0, 2, 3, 1).float().numpy() 303 image = image.cpu().permute(0, 2, 3, 1).float().numpy()
287 return image 304 return image
288 305
def vae_decode(self, latents):
    """
    Run the VAE decoder on `latents`, optionally one slice at a time.

    With `self.use_slicing` unset, the result of `self.vae.decode(latents)` is returned unchanged.
    With it set, `latents` is split along its first dimension into chunks of size one, each chunk is
    decoded on its own, and the decoded samples are concatenated back together and wrapped in a
    `DecoderOutput`, so callers can read `.sample` in both cases.
    """
    if not self.use_slicing:
        return self.vae.decode(latents)

    # Decode chunk by chunk so only one slice goes through the decoder at a time.
    pieces = []
    for chunk in latents.split(1):
        pieces.append(self.vae.decode(chunk).sample)
    return DecoderOutput(sample=torch.cat(pieces))
313
289 @torch.no_grad() 314 @torch.no_grad()
290 def __call__( 315 def __call__(
291 self, 316 self,