summaryrefslogtreecommitdiffstats
path: root/pipelines/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'pipelines/util.py')
-rw-r--r--pipelines/util.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/pipelines/util.py b/pipelines/util.py
new file mode 100644
index 0000000..661dbee
--- /dev/null
+++ b/pipelines/util.py
@@ -0,0 +1,9 @@
1import torch
2
3
4def set_use_memory_efficient_attention_xformers(module: torch.nn.Module, valid: bool) -> None:
5 if hasattr(module, "set_use_memory_efficient_attention_xformers"):
6 module.set_use_memory_efficient_attention_xformers(valid)
7
8 for child in module.children():
9 set_use_memory_efficient_attention_xformers(child, valid)