summaryrefslogtreecommitdiffstats
path: root/models/attention/hook.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-06-21 13:28:49 +0200
committerVolpeon <git@volpeon.ink>2023-06-21 13:28:49 +0200
commit8364ce697ddf6117fdd4f7222832d546d63880de (patch)
tree152c99815bbd8b2659d0dabe63c98f63151c97c2 /models/attention/hook.py
parentFix LoRA training with DAdan (diff)
downloadtextual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.gz
textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.bz2
textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.zip
Update
Diffstat (limited to 'models/attention/hook.py')
-rw-r--r--models/attention/hook.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/models/attention/hook.py b/models/attention/hook.py
index 903de02..6b5fb68 100644
--- a/models/attention/hook.py
+++ b/models/attention/hook.py
@@ -3,6 +3,7 @@ import torch
3 3
4try: 4try:
5 import xformers.ops 5 import xformers.ops
6
6 xformers._is_functorch_available = True 7 xformers._is_functorch_available = True
7 MEM_EFFICIENT_ATTN = True 8 MEM_EFFICIENT_ATTN = True
8except ImportError: 9except ImportError:
@@ -42,10 +43,10 @@ def register_attention_control(model, controller):
42 return forward 43 return forward
43 44
44 def register_recr(net_, count, place_in_unet): 45 def register_recr(net_, count, place_in_unet):
45 if net_.__class__.__name__ == 'CrossAttention': 46 if net_.__class__.__name__ == "CrossAttention":
46 net_.forward = ca_forward(net_, place_in_unet) 47 net_.forward = ca_forward(net_, place_in_unet)
47 return count + 1 48 return count + 1
48 elif hasattr(net_, 'children'): 49 elif hasattr(net_, "children"):
49 for net__ in net_.children(): 50 for net__ in net_.children():
50 count = register_recr(net__, count, place_in_unet) 51 count = register_recr(net__, count, place_in_unet)
51 return count 52 return count