| field | value | date |
|---|---|---|
| author | Volpeon <git@volpeon.ink> | 2022-11-30 14:02:35 +0100 |
| committer | Volpeon <git@volpeon.ink> | 2022-11-30 14:02:35 +0100 |
| commit | 329ad48b307e782b0e23fce80ae9087a4003e73d (patch) | |
| tree | 0c72434a8d45ae933582064849b43bd7419f7ee8 /pipelines | |
| parent | Adjusted training to upstream (diff) | |
| download | textual-inversion-diff-329ad48b307e782b0e23fce80ae9087a4003e73d.tar.gz, textual-inversion-diff-329ad48b307e782b0e23fce80ae9087a4003e73d.tar.bz2, textual-inversion-diff-329ad48b307e782b0e23fce80ae9087a4003e73d.zip | |
Update
Diffstat (limited to 'pipelines')
| mode | path | lines |
|---|---|---|
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 27 |

1 file changed, 26 insertions(+), 1 deletion(-)
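The change adds optional sliced VAE decoding to `VlpnStableDiffusion`: a `use_slicing` flag toggled by new `enable_vae_slicing`/`disable_vae_slicing` methods, and a `vae_decode` helper that, when slicing is enabled, decodes the latent batch one sample at a time to cut peak memory during decoding.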
```diff
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 85b0216..c77c4d1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -20,6 +20,7 @@ from diffusers import (
     PNDMScheduler,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.models.vae import DecoderOutput
 from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
 from models.clip.prompt import PromptProcessor
@@ -69,6 +70,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler._internal_dict = FrozenDict(new_config)
 
         self.prompt_processor = PromptProcessor(tokenizer, text_encoder)
+        self.use_slicing = False
 
         self.register_modules(
             vae=vae,
@@ -136,6 +138,21 @@ class VlpnStableDiffusion(DiffusionPipeline):
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding.
+        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+        steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.use_slicing = False
+
     @property
     def execution_device(self):
         r"""
@@ -280,12 +297,20 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
     def decode_latents(self, latents):
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae_decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
+    def vae_decode(self, latents):
+        if self.use_slicing:
+            decoded_slices = [self.vae.decode(latents_slice).sample for latents_slice in latents.split(1)]
+            decoded = torch.cat(decoded_slices)
+            return DecoderOutput(sample=decoded)
+        else:
+            return self.vae.decode(latents)
+
     @torch.no_grad()
     def __call__(
         self,
```
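The new `vae_decode` path trades a little throughput for peak memory: with slicing enabled, the latent batch is split into batch-size-1 slices, each slice is decoded on its own, and the per-sample results are concatenated, so decoder activations only ever exist for one image at a time. Below is a minimal, self-contained sketch of that pattern; the `decode` function is a toy stand-in for the VAE decoder (a plain 8x upsampler), not the real model.

```python
import torch

# Toy stand-in for the VAE decoder. The real decoder maps a (B, 4, H/8, W/8)
# latent to a (B, 3, H, W) image; here we only mimic the spatial upsampling,
# since the point is the memory pattern, not the network.
def decode(latents: torch.Tensor) -> torch.Tensor:
    return latents.repeat_interleave(8, dim=2).repeat_interleave(8, dim=3)

latents = torch.randn(4, 4, 64, 64)  # a batch of 4 latent images

# Batched decode: one call, activations for all 4 images live at once.
batched = decode(latents)

# Sliced decode, as in vae_decode above: split(1) yields batch-1 slices,
# so peak memory is bounded by a single image.
sliced = torch.cat([decode(s) for s in latents.split(1)])

assert torch.equal(batched, sliced)  # same result, lower peak memory
```

On a constructed pipeline the switch is just `pipe.enable_vae_slicing()` before the call, and `pipe.disable_vae_slicing()` to return to single-pass decoding; `pipe` here is a hypothetical `VlpnStableDiffusion` instance.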
