Diffstat (limited to 'models/attention/hook.py')
-rw-r--r--  models/attention/hook.py  62
1 file changed, 62 insertions(+), 0 deletions(-)
diff --git a/models/attention/hook.py b/models/attention/hook.py
new file mode 100644
index 0000000..903de02
--- /dev/null
+++ b/models/attention/hook.py
@@ -0,0 +1,62 @@
import torch


try:
    import xformers.ops
    xformers._is_functorch_available = True
    MEM_EFFICIENT_ATTN = True
except ImportError:
    print("[!] Not using xformers memory efficient attention.")
    MEM_EFFICIENT_ATTN = False


def register_attention_control(model, controller):
    def ca_forward(self, place_in_unet):
        def forward(x, context=None, mask=None):
            batch_size, sequence_length, dim = x.shape
            h = self.heads
            q = self.to_q(x)
            is_cross = context is not None
            context = context if is_cross else x
            k = self.to_k(context)
            v = self.to_v(context)
            q = self.reshape_heads_to_batch_dim(q)
            k = self.reshape_heads_to_batch_dim(k)
            v = self.reshape_heads_to_batch_dim(v)

            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

            if mask is not None:
                mask = mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = mask[:, None, :].repeat(h, 1, 1)
                sim.masked_fill_(~mask, max_neg_value)

            # attention, what we cannot get enough of
            attn = sim.softmax(dim=-1)
            attn = controller(attn, is_cross, place_in_unet)
            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out = self.reshape_batch_dim_to_heads(out)
            return self.to_out(out)

        return forward

    def register_recr(net_, count, place_in_unet):
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count

    cross_att_count = 0
    sub_nets = model.unet.named_children()
    for net in sub_nets:
        if "down" in net[0]:
            cross_att_count += register_recr(net[1], 0, "down")
        elif "up" in net[0]:
            cross_att_count += register_recr(net[1], 0, "up")
        elif "mid" in net[0]:
            cross_att_count += register_recr(net[1], 0, "mid")
    controller.num_att_layers = cross_att_count
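
Not part of the commit: a minimal usage sketch, assuming a diffusers-style pipeline object that exposes model.unet and a controller with a num_att_layers attribute and a __call__(attn, is_cross, place_in_unet) method. The AttentionRecorder class and the pipe variable below are hypothetical and only illustrate the interface the hook expects.

class AttentionRecorder:
    """Hypothetical controller: records cross-attention maps per UNet block."""
    def __init__(self):
        self.num_att_layers = 0  # filled in by register_attention_control
        self.maps = {"down": [], "mid": [], "up": []}

    def __call__(self, attn, is_cross, place_in_unet):
        if is_cross:
            self.maps[place_in_unet].append(attn.detach().cpu())
        return attn  # the patched forward keeps computing with the returned tensor

# controller = AttentionRecorder()
# register_attention_control(pipe, controller)  # pipe must expose pipe.unet
# run the sampling loop as usual; controller.maps then holds the cross-attention probabilities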