summaryrefslogtreecommitdiffstats
path: root/pipelines/util.py
blob: 661dbeebb36d58e4db4748e4f4ef9387feb6f3fe (plain) (blame)
1
2
3
4
5
6
7
8
9
import torch


def set_use_memory_efficient_attention_xformers(module: torch.nn.Module, valid: bool) -> None:
    """Toggle xformers memory-efficient attention on *module* and all its descendants.

    The tree is walked depth-first, parent before children, in the order given by
    ``torch.nn.Module.children()``. Each visited module that exposes its own
    ``set_use_memory_efficient_attention_xformers`` attribute is called with *valid*.
    Modules without that attribute are skipped, but their children are still visited.

    A submodule reachable through more than one parent is visited once per path.

    Args:
        module: Root of the module tree to update.
        valid: Flag forwarded unchanged to every module-level setter found.
    """
    not_found = object()  # private sentinel: distinguishes "no attribute" from any real value
    setter = getattr(module, "set_use_memory_efficient_attention_xformers", not_found)
    if setter is not not_found:
        setter(valid)

    for submodule in module.children():
        set_use_memory_efficient_attention_xformers(submodule, valid)