diff options
author | Volpeon <git@volpeon.ink> | 2022-12-01 13:45:21 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-01 13:45:21 +0100 |
commit | c5cc1318c2a7597fe62d3379e50187d0b0f22538 (patch) | |
tree | 66f2b939c5498849692836e368fce481e7eaf3d2 /pipelines/stable_diffusion | |
parent | Update (diff) | |
download | textual-inversion-diff-c5cc1318c2a7597fe62d3379e50187d0b0f22538.tar.gz textual-inversion-diff-c5cc1318c2a7597fe62d3379e50187d0b0f22538.tar.bz2 textual-inversion-diff-c5cc1318c2a7597fe62d3379e50187d0b0f22538.zip |
Update
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 27 |
1 files changed, 11 insertions, 16 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index c77c4d1..9b51763 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -20,7 +20,6 @@ from diffusers import ( | |||
20 | PNDMScheduler, | 20 | PNDMScheduler, |
21 | ) | 21 | ) |
22 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | 22 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput |
23 | from diffusers.models.vae import DecoderOutput | ||
24 | from diffusers.utils import logging | 23 | from diffusers.utils import logging |
25 | from transformers import CLIPTextModel, CLIPTokenizer | 24 | from transformers import CLIPTextModel, CLIPTokenizer |
26 | from models.clip.prompt import PromptProcessor | 25 | from models.clip.prompt import PromptProcessor |
@@ -70,7 +69,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
70 | scheduler._internal_dict = FrozenDict(new_config) | 69 | scheduler._internal_dict = FrozenDict(new_config) |
71 | 70 | ||
72 | self.prompt_processor = PromptProcessor(tokenizer, text_encoder) | 71 | self.prompt_processor = PromptProcessor(tokenizer, text_encoder) |
73 | self.use_slicing = False | ||
74 | 72 | ||
75 | self.register_modules( | 73 | self.register_modules( |
76 | vae=vae, | 74 | vae=vae, |
@@ -108,9 +106,14 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
108 | `attention_head_dim` must be a multiple of `slice_size`. | 106 | `attention_head_dim` must be a multiple of `slice_size`. |
109 | """ | 107 | """ |
110 | if slice_size == "auto": | 108 | if slice_size == "auto": |
111 | # half the attention head size is usually a good trade-off between | 109 | if isinstance(self.unet.config.attention_head_dim, int): |
112 | # speed and memory | 110 | # half the attention head size is usually a good trade-off between |
113 | slice_size = self.unet.config.attention_head_dim // 2 | 111 | # speed and memory |
112 | slice_size = self.unet.config.attention_head_dim // 2 | ||
113 | else: | ||
114 | # if `attention_head_dim` is a list, take the smallest head size | ||
115 | slice_size = min(self.unet.config.attention_head_dim) | ||
116 | |||
114 | self.unet.set_attention_slice(slice_size) | 117 | self.unet.set_attention_slice(slice_size) |
115 | 118 | ||
116 | def disable_attention_slicing(self): | 119 | def disable_attention_slicing(self): |
@@ -144,14 +147,14 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
144 | When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several | 147 | When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several |
145 | steps. This is useful to save some memory and allow larger batch sizes. | 148 | steps. This is useful to save some memory and allow larger batch sizes. |
146 | """ | 149 | """ |
147 | self.use_slicing = True | 150 | self.vae.enable_slicing() |
148 | 151 | ||
149 | def disable_vae_slicing(self): | 152 | def disable_vae_slicing(self): |
150 | r""" | 153 | r""" |
151 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to | 154 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to |
152 | computing decoding in one step. | 155 | computing decoding in one step. |
153 | """ | 156 | """ |
154 | self.use_slicing = False | 157 | self.vae.disable_slicing() |
155 | 158 | ||
156 | @property | 159 | @property |
157 | def execution_device(self): | 160 | def execution_device(self): |
@@ -297,20 +300,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
297 | 300 | ||
298 | def decode_latents(self, latents): | 301 | def decode_latents(self, latents): |
299 | latents = 1 / 0.18215 * latents | 302 | latents = 1 / 0.18215 * latents |
300 | image = self.vae_decode(latents).sample | 303 | image = self.vae.decode(latents).sample |
301 | image = (image / 2 + 0.5).clamp(0, 1) | 304 | image = (image / 2 + 0.5).clamp(0, 1) |
302 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 | 305 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 |
303 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() | 306 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() |
304 | return image | 307 | return image |
305 | 308 | ||
306 | def vae_decode(self, latents): | ||
307 | if self.use_slicing: | ||
308 | decoded_slices = [self.vae.decode(latents_slice).sample for latents_slice in latents.split(1)] | ||
309 | decoded = torch.cat(decoded_slices) | ||
310 | return DecoderOutput(sample=decoded) | ||
311 | else: | ||
312 | return self.vae.decode(latents) | ||
313 | |||
314 | @torch.no_grad() | 309 | @torch.no_grad() |
315 | def __call__( | 310 | def __call__( |
316 | self, | 311 | self, |