From 8364ce697ddf6117fdd4f7222832d546d63880de Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 21 Jun 2023 13:28:49 +0200 Subject: Update --- models/attention/hook.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'models/attention/hook.py') diff --git a/models/attention/hook.py b/models/attention/hook.py index 903de02..6b5fb68 100644 --- a/models/attention/hook.py +++ b/models/attention/hook.py @@ -3,6 +3,7 @@ import torch try: import xformers.ops + xformers._is_functorch_available = True MEM_EFFICIENT_ATTN = True except ImportError: @@ -42,10 +43,10 @@ def register_attention_control(model, controller): return forward def register_recr(net_, count, place_in_unet): - if net_.__class__.__name__ == 'CrossAttention': + if net_.__class__.__name__ == "CrossAttention": net_.forward = ca_forward(net_, place_in_unet) return count + 1 - elif hasattr(net_, 'children'): + elif hasattr(net_, "children"): for net__ in net_.children(): count = register_recr(net__, count, place_in_unet) return count -- cgit v1.2.3-54-g00ecf