author     Volpeon <git@volpeon.ink>  2022-11-30 14:02:35 +0100
committer  Volpeon <git@volpeon.ink>  2022-11-30 14:02:35 +0100
commit     329ad48b307e782b0e23fce80ae9087a4003e73d
tree       0c72434a8d45ae933582064849b43bd7419f7ee8 /pipelines
parent     Adjusted training to upstream
Update
Diffstat (limited to 'pipelines')
 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 27 ++++++++++++++++++++++++++++++-
 1 file changed, 26 insertions(+), 1 deletion(-)
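The commit adds opt-in sliced VAE decoding: instead of decoding the whole latent batch in one VAE forward pass, the batch is split into single-sample slices that are decoded sequentially and concatenated, trading a little throughput for a smaller peak-memory footprint. A standalone sketch of the split/decode/concatenate pattern (the stand-in decode function and the tensor shapes are illustrative only, not from the commit):

import torch

def decode(x):
    # Stand-in for self.vae.decode(x).sample; the real decoder changes the shape.
    return x * 2

latents = torch.randn(4, 4, 64, 64)               # illustrative batch of 4 latents
slices = latents.split(1)                         # four (1, 4, 64, 64) tensors
decoded = torch.cat([decode(s) for s in slices])  # decode one sample at a time
assert decoded.shape == latents.shape             # batch is reassembled intact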
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 85b0216..c77c4d1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -20,6 +20,7 @@ from diffusers import (
     PNDMScheduler,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.models.vae import DecoderOutput
 from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
 from models.clip.prompt import PromptProcessor
@@ -69,6 +70,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler._internal_dict = FrozenDict(new_config)
 
         self.prompt_processor = PromptProcessor(tokenizer, text_encoder)
+        self.use_slicing = False
 
         self.register_modules(
             vae=vae,
@@ -136,6 +138,21 @@ class VlpnStableDiffusion(DiffusionPipeline):
         if cpu_offloaded_model is not None:
             cpu_offload(cpu_offloaded_model, device)
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding.
+        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+        steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.use_slicing = False
+
     @property
     def execution_device(self):
         r"""
@@ -280,12 +297,20 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
     def decode_latents(self, latents):
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae_decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
+    def vae_decode(self, latents):
+        if self.use_slicing:
+            decoded_slices = [self.vae.decode(latents_slice).sample for latents_slice in latents.split(1)]
+            decoded = torch.cat(decoded_slices)
+            return DecoderOutput(sample=decoded)
+        else:
+            return self.vae.decode(latents)
+
     @torch.no_grad()
     def __call__(
         self,
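A hedged usage sketch of the new switch. The wiring below assumes the pipeline can be assembled from a standard Stable Diffusion checkpoint's components; the checkpoint id, scheduler choice, constructor signature, and prompt are placeholders and assumptions, not part of the commit:

import torch
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

repo = "runwayml/stable-diffusion-v1-5"  # placeholder checkpoint id

# Assumed constructor wiring, inferred from the modules the class registers.
pipe = VlpnStableDiffusion(
    vae=AutoencoderKL.from_pretrained(repo, subfolder="vae"),
    text_encoder=CLIPTextModel.from_pretrained(repo, subfolder="text_encoder"),
    tokenizer=CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer"),
    unet=UNet2DConditionModel.from_pretrained(repo, subfolder="unet"),
    scheduler=DDIMScheduler.from_pretrained(repo, subfolder="scheduler"),
).to("cuda")

pipe.enable_vae_slicing()   # decode_latents now takes the sliced path
images = pipe(["a watercolor fox"] * 4).images  # 4 latents decoded one at a time
pipe.disable_vae_slicing()  # restore single-pass decoding

Because `vae_decode` wraps the concatenated slices in `DecoderOutput`, callers that read `.sample` see the same interface whether slicing is enabled or not.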