diff options
Diffstat (limited to 'models/attention/hook.py')
-rw-r--r-- | models/attention/hook.py | 5 |
1 file changed, 3 insertions, 2 deletions
diff --git a/models/attention/hook.py b/models/attention/hook.py index 903de02..6b5fb68 100644 --- a/models/attention/hook.py +++ b/models/attention/hook.py | |||
@@ -3,6 +3,7 @@ import torch | |||
3 | 3 | ||
4 | try: | 4 | try: |
5 | import xformers.ops | 5 | import xformers.ops |
6 | |||
6 | xformers._is_functorch_available = True | 7 | xformers._is_functorch_available = True |
7 | MEM_EFFICIENT_ATTN = True | 8 | MEM_EFFICIENT_ATTN = True |
8 | except ImportError: | 9 | except ImportError: |
@@ -42,10 +43,10 @@ def register_attention_control(model, controller): | |||
42 | return forward | 43 | return forward |
43 | 44 | ||
44 | def register_recr(net_, count, place_in_unet): | 45 | def register_recr(net_, count, place_in_unet): |
45 | if net_.__class__.__name__ == 'CrossAttention': | 46 | if net_.__class__.__name__ == "CrossAttention": |
46 | net_.forward = ca_forward(net_, place_in_unet) | 47 | net_.forward = ca_forward(net_, place_in_unet) |
47 | return count + 1 | 48 | return count + 1 |
48 | elif hasattr(net_, 'children'): | 49 | elif hasattr(net_, "children"): |
49 | for net__ in net_.children(): | 50 | for net__ in net_.children(): |
50 | count = register_recr(net__, count, place_in_unet) | 51 | count = register_recr(net__, count, place_in_unet) |
51 | return count | 52 | return count |