1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
|
import torch
# Optional dependency probe: record in MEM_EFFICIENT_ATTN whether
# ``xformers.ops`` could be imported.
try:
    import xformers.ops
except ImportError:
    # xformers is optional; report the fallback and continue without it.
    print("[!] Not using xformers memory efficient attention.")
    MEM_EFFICIENT_ATTN = False
else:
    # Sets a private flag on the xformers package right after a successful
    # import. NOTE(review): the reason for forcing this flag is not visible
    # in this file -- confirm it is still required by the xformers version
    # in use.
    xformers._is_functorch_available = True
    MEM_EFFICIENT_ATTN = True
def register_attention_control(model, controller):
    """Route every ``CrossAttention`` layer of ``model.unet`` through ``controller``.

    Each direct child of ``model.unet`` whose name contains ``"down"``,
    ``"up"`` or ``"mid"`` (checked in that order) is searched recursively.
    Every module found whose class is named ``CrossAttention`` gets its
    ``forward`` replaced by one that computes the attention probabilities
    explicitly and passes them to ``controller`` before they are applied to
    the values.

    Args:
        model: Object exposing a ``unet`` attribute that provides
            ``named_children()``.
        controller: Callable invoked as
            ``controller(attn, is_cross, place_in_unet)``; its return value is
            used as the attention probabilities. After registration its
            ``num_att_layers`` attribute is set to the number of patched
            layers.
    """

    def ca_forward(self, place_in_unet):
        """Build the replacement ``forward`` for the attention module ``self``."""

        def forward(x, context=None, mask=None):
            # The unpacking also enforces that ``x`` is 3-D.
            batch_size, _, _ = x.shape
            h = self.heads

            q = self.to_q(x)
            # Without a context the layer attends over its own input.
            is_cross = context is not None
            context = context if is_cross else x
            k = self.to_k(context)
            v = self.to_v(context)

            q = self.reshape_heads_to_batch_dim(q)
            k = self.reshape_heads_to_batch_dim(k)
            v = self.reshape_heads_to_batch_dim(v)

            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

            if mask is not None:
                mask = mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                # Expand the per-sample mask to one row per (sample, head).
                # ``repeat_interleave`` keeps all heads of a sample adjacent,
                # so row n uses mask[n // h]; tiling with ``repeat`` would
                # give mask[n % batch_size] and mix samples when batch > 1.
                # NOTE(review): assumes ``reshape_heads_to_batch_dim`` lays
                # rows out as (batch, head) with head varying fastest, as the
                # diffusers CrossAttention does -- confirm.
                mask = mask[:, None, :].repeat_interleave(h, dim=0)
                # ``~mask`` requires a boolean mask; False entries are hidden.
                sim.masked_fill_(~mask, max_neg_value)

            # attention, what we cannot get enough of
            attn = sim.softmax(dim=-1)
            # Hand the probabilities to the controller before using them.
            attn = controller(attn, is_cross, place_in_unet)
            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out = self.reshape_batch_dim_to_heads(out)

            to_out = self.to_out
            if isinstance(to_out, torch.nn.ModuleList):
                # A ModuleList is a container, not a callable layer: apply
                # its members in order instead of calling it directly.
                for layer in to_out:
                    out = layer(out)
                return out
            return to_out(out)

        return forward

    def register_recr(net_, count, place_in_unet):
        """Patch attention layers under ``net_``; return ``count`` plus the number patched."""
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count

    cross_att_count = 0
    for name, child in model.unet.named_children():
        # First matching tag wins, in the order down -> up -> mid.
        for place in ("down", "up", "mid"):
            if place in name:
                cross_att_count += register_recr(child, 0, place)
                break
    controller.num_att_layers = cross_att_count
|