diff options
Diffstat (limited to 'pipelines/util.py')
-rw-r--r-- | pipelines/util.py | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/pipelines/util.py b/pipelines/util.py new file mode 100644 index 0000000..661dbee --- /dev/null +++ b/pipelines/util.py | |||
@@ -0,0 +1,9 @@ | |||
1 | import torch | ||
2 | |||
3 | |||
4 | def set_use_memory_efficient_attention_xformers(module: torch.nn.Module, valid: bool) -> None: | ||
5 | if hasattr(module, "set_use_memory_efficient_attention_xformers"): | ||
6 | module.set_use_memory_efficient_attention_xformers(valid) | ||
7 | |||
8 | for child in module.children(): | ||
9 | set_use_memory_efficient_attention_xformers(child, valid) | ||