Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py | 36
1 file changed, 34 insertions(+), 2 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index cd5ae7e..36942f0 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -7,6 +7,7 @@ import torch
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
+from diffusers.utils import is_accelerate_available
 from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
@@ -61,13 +62,27 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler=scheduler,
         )
 
+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(False)
+
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
-
         When this option is enabled, the attention module will split the input tensor in slices, to compute attention
         in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
         Args:
             slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                 When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
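
Usage note (not part of the patch): the new toggles above are thin wrappers around the UNet's xformers switch, so they can be flipped at any point on a constructed pipeline. A minimal sketch, assuming `pipe` is an already constructed VlpnStableDiffusion instance on a CUDA device with xformers installed:

# Sketch only; `pipe` is assumed to be a VlpnStableDiffusion instance that has
# already been constructed and moved to a CUDA device, with xformers installed.
pipe.enable_xformers_memory_efficient_attention()   # route attention through xformers' memory-efficient kernels
# ... run inference here; per the docstring above, if sliced attention is also
# enabled, the memory-efficient path takes precedence.
pipe.disable_xformers_memory_efficient_attention()  # restore the default attention implementation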
@@ -88,6 +103,23 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
+    def enable_sequential_cpu_offload(self):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device("cuda")
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
     @torch.no_grad()
     def __call__(
         self,
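
Usage note (not part of the patch): enable_sequential_cpu_offload keeps the weights on the CPU and lets accelerate's hooks stream each submodule to the GPU only for its forward pass, so it is typically not combined with moving the whole pipeline to CUDA first. A minimal sketch, assuming `pipe` is a freshly constructed VlpnStableDiffusion instance still on the CPU and that its __call__ accepts a text prompt like the stock Stable Diffusion pipeline:

# Sketch only; `pipe` is assumed to be a freshly constructed VlpnStableDiffusion
# instance still on the CPU, with accelerate installed.
pipe.enable_sequential_cpu_offload()  # weights stay on CPU; each submodule hops to the GPU for its forward pass
# The prompt-based call below assumes the usual Stable Diffusion interface and
# that the pipeline returns a StableDiffusionPipelineOutput with an `images` list.
image = pipe("a photograph of an astronaut riding a horse").images[0]
image.save("astronaut.png")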