import torch


def set_use_memory_efficient_attention_xformers(module: torch.nn.Module, valid: bool) -> None:
    """Recursively enable or disable xformers memory-efficient attention on a module tree."""
    # Call the setter on this module if it exposes one.
    if hasattr(module, "set_use_memory_efficient_attention_xformers"):
        module.set_use_memory_efficient_attention_xformers(valid)

    # Recurse into all children so nested attention blocks are covered as well.
    for child in module.children():
        set_use_memory_efficient_attention_xformers(child, valid)
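

# Minimal usage sketch. `DummyAttention` is a hypothetical stand-in for a real attention
# block (e.g. one from diffusers) that exposes the same setter method; it is not part of
# the original file and only illustrates how the recursive helper walks a module tree.
if __name__ == "__main__":
    class DummyAttention(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.xformers_enabled = False

        def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
            self.xformers_enabled = valid

    model = torch.nn.Sequential(DummyAttention(), torch.nn.Linear(4, 4), DummyAttention())
    set_use_memory_efficient_attention_xformers(model, True)
    # Both DummyAttention children have been toggled: prints [True, True]
    print([m.xformers_enabled for m in model if isinstance(m, DummyAttention)])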