diff options
Diffstat (limited to 'pipelines')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 16 | ||||
-rw-r--r-- | pipelines/util.py | 9 |
2 files changed, 9 insertions, 16 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 9b51763..f80e951 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -78,22 +78,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
78 | scheduler=scheduler, | 78 | scheduler=scheduler, |
79 | ) | 79 | ) |
80 | 80 | ||
81 | def enable_xformers_memory_efficient_attention(self): | ||
82 | r""" | ||
83 | Enable memory efficient attention as implemented in xformers. | ||
84 | When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference | ||
85 | time. Speed up at training time is not guaranteed. | ||
86 | Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention | ||
87 | is used. | ||
88 | """ | ||
89 | self.unet.set_use_memory_efficient_attention_xformers(True) | ||
90 | |||
91 | def disable_xformers_memory_efficient_attention(self): | ||
92 | r""" | ||
93 | Disable memory efficient attention as implemented in xformers. | ||
94 | """ | ||
95 | self.unet.set_use_memory_efficient_attention_xformers(False) | ||
96 | |||
97 | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): | 81 | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): |
98 | r""" | 82 | r""" |
99 | Enable sliced attention computation. | 83 | Enable sliced attention computation. |
diff --git a/pipelines/util.py b/pipelines/util.py new file mode 100644 index 0000000..661dbee --- /dev/null +++ b/pipelines/util.py | |||
@@ -0,0 +1,9 @@ | |||
1 | import torch | ||
2 | |||
3 | |||
4 | def set_use_memory_efficient_attention_xformers(module: torch.nn.Module, valid: bool) -> None: | ||
5 | if hasattr(module, "set_use_memory_efficient_attention_xformers"): | ||
6 | module.set_use_memory_efficient_attention_xformers(valid) | ||
7 | |||
8 | for child in module.children(): | ||
9 | set_use_memory_efficient_attention_xformers(child, valid) | ||