diff options
Diffstat (limited to 'models/attention/hook.py')
| -rw-r--r-- | models/attention/hook.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/models/attention/hook.py b/models/attention/hook.py index 903de02..6b5fb68 100644 --- a/models/attention/hook.py +++ b/models/attention/hook.py | |||
| @@ -3,6 +3,7 @@ import torch | |||
| 3 | 3 | ||
| 4 | try: | 4 | try: |
| 5 | import xformers.ops | 5 | import xformers.ops |
| 6 | |||
| 6 | xformers._is_functorch_available = True | 7 | xformers._is_functorch_available = True |
| 7 | MEM_EFFICIENT_ATTN = True | 8 | MEM_EFFICIENT_ATTN = True |
| 8 | except ImportError: | 9 | except ImportError: |
| @@ -42,10 +43,10 @@ def register_attention_control(model, controller): | |||
| 42 | return forward | 43 | return forward |
| 43 | 44 | ||
| 44 | def register_recr(net_, count, place_in_unet): | 45 | def register_recr(net_, count, place_in_unet): |
| 45 | if net_.__class__.__name__ == 'CrossAttention': | 46 | if net_.__class__.__name__ == "CrossAttention": |
| 46 | net_.forward = ca_forward(net_, place_in_unet) | 47 | net_.forward = ca_forward(net_, place_in_unet) |
| 47 | return count + 1 | 48 | return count + 1 |
| 48 | elif hasattr(net_, 'children'): | 49 | elif hasattr(net_, "children"): |
| 49 | for net__ in net_.children(): | 50 | for net__ in net_.children(): |
| 50 | count = register_recr(net__, count, place_in_unet) | 51 | count = register_recr(net__, count, place_in_unet) |
| 51 | return count | 52 | return count |
