Diffstat (limited to 'models/attention/hook.py')
-rw-r--r--    models/attention/hook.py    63
1 file changed, 0 insertions(+), 63 deletions(-)
diff --git a/models/attention/hook.py b/models/attention/hook.py
deleted file mode 100644
index 6b5fb68..0000000
--- a/models/attention/hook.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import torch
-
-
-try:
-    import xformers.ops
-
-    xformers._is_functorch_available = True
-    MEM_EFFICIENT_ATTN = True
-except ImportError:
-    print("[!] Not using xformers memory efficient attention.")
-    MEM_EFFICIENT_ATTN = False
-
-
-def register_attention_control(model, controller):
-    def ca_forward(self, place_in_unet):
-        def forward(x, context=None, mask=None):
-            batch_size, sequence_length, dim = x.shape
-            h = self.heads
-            q = self.to_q(x)
-            is_cross = context is not None
-            context = context if is_cross else x
-            k = self.to_k(context)
-            v = self.to_v(context)
-            q = self.reshape_heads_to_batch_dim(q)
-            k = self.reshape_heads_to_batch_dim(k)
-            v = self.reshape_heads_to_batch_dim(v)
-
-            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
-
-            if mask is not None:
-                mask = mask.reshape(batch_size, -1)
-                max_neg_value = -torch.finfo(sim.dtype).max
-                mask = mask[:, None, :].repeat(h, 1, 1)
-                sim.masked_fill_(~mask, max_neg_value)
-
-            # attention, what we cannot get enough of
-            attn = sim.softmax(dim=-1)
-            attn = controller(attn, is_cross, place_in_unet)
-            out = torch.einsum("b i j, b j d -> b i d", attn, v)
-            out = self.reshape_batch_dim_to_heads(out)
-            return self.to_out(out)
-
-        return forward
-
-    def register_recr(net_, count, place_in_unet):
-        if net_.__class__.__name__ == "CrossAttention":
-            net_.forward = ca_forward(net_, place_in_unet)
-            return count + 1
-        elif hasattr(net_, "children"):
-            for net__ in net_.children():
-                count = register_recr(net__, count, place_in_unet)
-        return count
-
-    cross_att_count = 0
-    sub_nets = model.unet.named_children()
-    for net in sub_nets:
-        if "down" in net[0]:
-            cross_att_count += register_recr(net[1], 0, "down")
-        elif "up" in net[0]:
-            cross_att_count += register_recr(net[1], 0, "up")
-        elif "mid" in net[0]:
-            cross_att_count += register_recr(net[1], 0, "mid")
-    controller.num_att_layers = cross_att_count
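
For reference, the deleted hook was intended to be driven by a caller along these lines: register_attention_control monkey-patches every CrossAttention.forward inside model.unet so each attention map is routed through a controller callable, then writes the number of patched layers back to controller.num_att_layers. The sketch below is a minimal, hypothetical usage example and is not part of the removed file; DummyController and the pipe variable (any object exposing a .unet module, e.g. a diffusers pipeline) are assumptions.

class DummyController:
    """Illustrative controller: records cross-attention maps without altering them."""

    def __init__(self):
        self.num_att_layers = 0  # filled in by register_attention_control
        self.step_store = []

    def __call__(self, attn, is_cross, place_in_unet):
        if is_cross:
            # Keep a CPU copy of the map; return the tensor unchanged so sampling is unaffected.
            self.step_store.append((place_in_unet, attn.detach().cpu()))
        return attn


controller = DummyController()
register_attention_control(pipe, controller)  # assumes pipe.unet exists
print(f"hooked {controller.num_att_layers} cross-attention layers")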