Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py | 36
1 file changed, 34 insertions(+), 2 deletions(-)
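The commit adds three memory-related helpers to VlpnStableDiffusion: enable_xformers_memory_efficient_attention(), disable_xformers_memory_efficient_attention(), and enable_sequential_cpu_offload(). As a rough usage sketch (not part of the commit; the checkpoint path is a placeholder and the generation call is assumed to follow the standard Stable Diffusion pipeline convention):

# Illustrative only: the model path is a placeholder and loading details may differ.
import torch
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

pipe = VlpnStableDiffusion.from_pretrained(
    "path/to/stable-diffusion-checkpoint",  # placeholder checkpoint path
    torch_dtype=torch.float16,
).to("cuda")

# Memory-efficient attention via xformers (requires the xformers package).
pipe.enable_xformers_memory_efficient_attention()

# Alternatively, trade speed for memory by offloading submodules to CPU between
# forward calls (requires `pip install accelerate`); in that case skip .to("cuda")
# above, since modules are moved to the GPU on demand.
# pipe.enable_sequential_cpu_offload()

# Generate (the __call__ signature is assumed to mirror StableDiffusionPipeline).
image = pipe("a photograph of an astronaut riding a horse").images[0]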
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index cd5ae7e..36942f0 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -7,6 +7,7 @@ import torch
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
+from diffusers.utils import is_accelerate_available
 from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
@@ -61,13 +62,27 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler=scheduler,
         )
 
+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(False)
+
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
-
         When this option is enabled, the attention module will split the input tensor in slices, to compute attention
         in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
         Args:
             slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                 When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
@@ -88,6 +103,23 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
+    def enable_sequential_cpu_offload(self):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device("cuda")
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
     @torch.no_grad()
     def __call__(
         self,