summaryrefslogtreecommitdiffstats
path: root/models/attention/hook.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/attention/hook.py')
-rw-r--r--models/attention/hook.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/models/attention/hook.py b/models/attention/hook.py
index 903de02..6b5fb68 100644
--- a/models/attention/hook.py
+++ b/models/attention/hook.py
@@ -3,6 +3,7 @@ import torch
3 3
4try: 4try:
5 import xformers.ops 5 import xformers.ops
6
6 xformers._is_functorch_available = True 7 xformers._is_functorch_available = True
7 MEM_EFFICIENT_ATTN = True 8 MEM_EFFICIENT_ATTN = True
8except ImportError: 9except ImportError:
@@ -42,10 +43,10 @@ def register_attention_control(model, controller):
42 return forward 43 return forward
43 44
44 def register_recr(net_, count, place_in_unet): 45 def register_recr(net_, count, place_in_unet):
45 if net_.__class__.__name__ == 'CrossAttention': 46 if net_.__class__.__name__ == "CrossAttention":
46 net_.forward = ca_forward(net_, place_in_unet) 47 net_.forward = ca_forward(net_, place_in_unet)
47 return count + 1 48 return count + 1
48 elif hasattr(net_, 'children'): 49 elif hasattr(net_, "children"):
49 for net__ in net_.children(): 50 for net__ in net_.children():
50 count = register_recr(net__, count, place_in_unet) 51 count = register_recr(net__, count, place_in_unet)
51 return count 52 return count