import torch
from torch import einsum
from einops import rearrange, repeat

from .control import AttentionControl


# Small helpers used by the attention forward below.
def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


class StructuredAttentionControl(AttentionControl):
    # Expects the usual cross-attention attributes (`heads`, `scale`, `to_q`,
    # `to_k`, `to_v`, `to_out`) and a `struct_attn` flag to be set on the instance.

    # AttentionControl hook: pass attention maps through unchanged.
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        return attn
    # Cross-attention forward. A list-valued `context` carries a structured
    # prompt: [uc_context, {'k': ..., 'v': ...}]. Note that this definition
    # shadows the hook above at class-creation time.
    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        if isinstance(context, list):
            if self.struct_attn:
                out = self.struct_qkv(q, context, mask)
            else:
                # Fall back to plain attention over the first conditional key tensor.
                context = torch.cat([context[0], context[1]['k'][0]], dim=0)
                out = self.normal_qkv(q, context, mask)
        else:
            context = default(context, x)
            out = self.normal_qkv(q, context, mask)

        return self.to_out(out)
    def struct_qkv(self, q, context, mask):
        """
        context: list of [uc, list of conditional context]
        """
        uc_context = context[0]
        context_k, context_v = context[1]['k'], context[1]['v']

        if isinstance(context_k, list) and isinstance(context_v, list):
            # One key/value tensor per phrase: attend to each separately.
            out = self.multi_qkv(q, uc_context, context_k, context_v, mask)
        elif isinstance(context_k, torch.Tensor) and isinstance(context_v, torch.Tensor):
            # Single conditional key/value tensors, possibly of a different
            # length than the unconditional context.
            out = self.heterogeous_qkv(q, uc_context, context_k, context_v, mask)
        else:
            raise NotImplementedError

        return out
    def multi_qkv(self, q, uc_context, context_k, context_v, mask):
        h = self.heads

        assert uc_context.size(0) == context_k[0].size(0) == context_v[0].size(0)
        true_bs = uc_context.size(0) * h

        k_uc, v_uc = self.get_kv(uc_context)
        k_c = [self.to_k(c_k) for c_k in context_k]
        v_c = [self.to_v(c_v) for c_v in context_v]

        q = rearrange(q, 'b n (h d) -> (b h) n d', h=h)

        k_uc = rearrange(k_uc, 'b n (h d) -> (b h) n d', h=h)
        v_uc = rearrange(v_uc, 'b n (h d) -> (b h) n d', h=h)

        k_c = [rearrange(k, 'b n (h d) -> (b h) n d', h=h) for k in k_c]  # NOTE: modification point
        v_c = [rearrange(v, 'b n (h d) -> (b h) n d', h=h) for v in v_c]

        # The batch holds the unconditional rows first, then the conditional
        # ones; `true_bs` marks the split after heads are folded into the batch.
        sim_uc = einsum('b i d, b j d -> b i j', q[:true_bs], k_uc) * self.scale
        sim_c = [einsum('b i d, b j d -> b i j', q[true_bs:], k) * self.scale for k in k_c]

        attn_uc = sim_uc.softmax(dim=-1)
        attn_c = [sim.softmax(dim=-1) for sim in sim_c]

        # Unconditional output.
        out_uc = einsum('b i j, b j d -> b i d', attn_uc, v_uc)

        # Conditional output: average the per-phrase attention results. With a
        # single value tensor, every attention map reuses it; otherwise keys
        # and values are paired one-to-one.
        if len(v_c) == 1:
            out_c_collect = []
            for attn in attn_c:
                for v in v_c:
                    out_c_collect.append(einsum('b i j, b j d -> b i d', attn, v))
            out_c = sum(out_c_collect) / len(out_c_collect)
        else:
            out_c = sum([einsum('b i j, b j d -> b i d', attn, v) for attn, v in zip(attn_c, v_c)]) / len(v_c)

        out = torch.cat([out_uc, out_c], dim=0)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)

        return out
    def normal_qkv(self, q, context, mask):
        h = self.heads
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)

        return out
    def heterogeous_qkv(self, q, uc_context, context_k, context_v, mask):
        h = self.heads
        # Project keys/values over the concatenated unconditional + conditional contexts.
        k = self.to_k(torch.cat([uc_context, context_k], dim=0))
        v = self.to_v(torch.cat([uc_context, context_v], dim=0))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return out
    def get_kv(self, context):
        return self.to_k(context), self.to_v(context)
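
# Usage sketch (illustrative only; `attn_layer`, `x`, `uc_emb`, `k_embs`, and
# `v_embs` are hypothetical names, not part of this module). The structured
# path expects `context` shaped as [uc_context, {'k': [...], 'v': [...]}],
# where each entry is a (batch, tokens, dim) text-embedding tensor:
#
#   attn_layer.struct_attn = True
#   context = [uc_emb, {'k': k_embs, 'v': v_embs}]
#   out = attn_layer(x, context=context, mask=None)   # routed through struct_qkv
#   out = attn_layer(x, context=uc_emb)               # plain attention path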