diff options
author | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
commit | 8364ce697ddf6117fdd4f7222832d546d63880de (patch) | |
tree | 152c99815bbd8b2659d0dabe63c98f63151c97c2 /models/attention/hook.py | |
parent | Fix LoRA training with DAdan (diff) | |
download | textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.gz textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.bz2 textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.zip |
Update
Diffstat (limited to 'models/attention/hook.py')
-rw-r--r-- | models/attention/hook.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/models/attention/hook.py b/models/attention/hook.py index 903de02..6b5fb68 100644 --- a/models/attention/hook.py +++ b/models/attention/hook.py | |||
@@ -3,6 +3,7 @@ import torch | |||
3 | 3 | ||
4 | try: | 4 | try: |
5 | import xformers.ops | 5 | import xformers.ops |
6 | |||
6 | xformers._is_functorch_available = True | 7 | xformers._is_functorch_available = True |
7 | MEM_EFFICIENT_ATTN = True | 8 | MEM_EFFICIENT_ATTN = True |
8 | except ImportError: | 9 | except ImportError: |
@@ -42,10 +43,10 @@ def register_attention_control(model, controller): | |||
42 | return forward | 43 | return forward |
43 | 44 | ||
44 | def register_recr(net_, count, place_in_unet): | 45 | def register_recr(net_, count, place_in_unet): |
45 | if net_.__class__.__name__ == 'CrossAttention': | 46 | if net_.__class__.__name__ == "CrossAttention": |
46 | net_.forward = ca_forward(net_, place_in_unet) | 47 | net_.forward = ca_forward(net_, place_in_unet) |
47 | return count + 1 | 48 | return count + 1 |
48 | elif hasattr(net_, 'children'): | 49 | elif hasattr(net_, "children"): |
49 | for net__ in net_.children(): | 50 | for net__ in net_.children(): |
50 | count = register_recr(net__, count, place_in_unet) | 51 | count = register_recr(net__, count, place_in_unet) |
51 | return count | 52 | return count |