summaryrefslogtreecommitdiffstats
path: root/pipelines
diff options
context:
space:
mode:
Diffstat (limited to 'pipelines')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py16
-rw-r--r--pipelines/util.py9
2 files changed, 9 insertions, 16 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 9b51763..f80e951 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -78,22 +78,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
78 scheduler=scheduler, 78 scheduler=scheduler,
79 ) 79 )
80 80
81 def enable_xformers_memory_efficient_attention(self):
82 r"""
83 Enable memory efficient attention as implemented in xformers.
84 When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
85 time. Speed up at training time is not guaranteed.
86 Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
87 is used.
88 """
89 self.unet.set_use_memory_efficient_attention_xformers(True)
90
91 def disable_xformers_memory_efficient_attention(self):
92 r"""
93 Disable memory efficient attention as implemented in xformers.
94 """
95 self.unet.set_use_memory_efficient_attention_xformers(False)
96
97 def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): 81 def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
98 r""" 82 r"""
99 Enable sliced attention computation. 83 Enable sliced attention computation.
diff --git a/pipelines/util.py b/pipelines/util.py
new file mode 100644
index 0000000..661dbee
--- /dev/null
+++ b/pipelines/util.py
@@ -0,0 +1,9 @@
1import torch
2
3
4def set_use_memory_efficient_attention_xformers(module: torch.nn.Module, valid: bool) -> None:
5 if hasattr(module, "set_use_memory_efficient_attention_xformers"):
6 module.set_use_memory_efficient_attention_xformers(valid)
7
8 for child in module.children():
9 set_use_memory_efficient_attention_xformers(child, valid)