diff options
| author | Volpeon <git@volpeon.ink> | 2022-10-24 07:34:30 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-24 07:34:30 +0200 |
| commit | 847ec3b6c43c89ef3649715f86ecfed370b6e442 (patch) | |
| tree | fb0a0312a685e45129763079a0b093f6db7865c4 /models/attention/structured.py | |
| parent | Revert lat; fix skip attribute in dataset (diff) | |
| download | textual-inversion-diff-847ec3b6c43c89ef3649715f86ecfed370b6e442.tar.gz textual-inversion-diff-847ec3b6c43c89ef3649715f86ecfed370b6e442.tar.bz2 textual-inversion-diff-847ec3b6c43c89ef3649715f86ecfed370b6e442.zip | |
Update
Diffstat (limited to 'models/attention/structured.py')
| -rw-r--r-- | models/attention/structured.py | 132 |
1 files changed, 132 insertions, 0 deletions
diff --git a/models/attention/structured.py b/models/attention/structured.py new file mode 100644 index 0000000..24d889f --- /dev/null +++ b/models/attention/structured.py | |||
| @@ -0,0 +1,132 @@ | |||
| 1 | import torch | ||
| 2 | |||
| 3 | from .control import AttentionControl | ||
| 4 | |||
| 5 | |||
| 6 | class StructuredAttentionControl(AttentionControl): | ||
| 7 | def forward(self, attn, is_cross: bool, place_in_unet: str): | ||
| 8 | return attn | ||
| 9 | |||
| 10 | def forward(self, x, context=None, mask=None): | ||
| 11 | h = self.heads | ||
| 12 | |||
| 13 | q = self.to_q(x) | ||
| 14 | |||
| 15 | if isinstance(context, list): | ||
| 16 | if self.struct_attn: | ||
| 17 | out = self.struct_qkv(q, context, mask) | ||
| 18 | else: | ||
| 19 | context = torch.cat([context[0], context[1]['k'][0]], dim=0) # use key tensor for context | ||
| 20 | out = self.normal_qkv(q, context, mask) | ||
| 21 | else: | ||
| 22 | context = default(context, x) | ||
| 23 | out = self.normal_qkv(q, context, mask) | ||
| 24 | |||
| 25 | return self.to_out(out) | ||
| 26 | |||
| 27 | def struct_qkv(self, q, context, mask): | ||
| 28 | """ | ||
| 29 | context: list of [uc, list of conditional context] | ||
| 30 | """ | ||
| 31 | uc_context = context[0] | ||
| 32 | context_k, context_v = context[1]['k'], context[1]['v'] | ||
| 33 | |||
| 34 | if isinstance(context_k, list) and isinstance(context_v, list): | ||
| 35 | out = self.multi_qkv(q, uc_context, context_k, context_v, mask) | ||
| 36 | elif isinstance(context_k, torch.Tensor) and isinstance(context_v, torch.Tensor): | ||
| 37 | out = self.heterogeous_qkv(q, uc_context, context_k, context_v, mask) | ||
| 38 | else: | ||
| 39 | raise NotImplementedError | ||
| 40 | |||
| 41 | return out | ||
| 42 | |||
| 43 | def multi_qkv(self, q, uc_context, context_k, context_v, mask): | ||
| 44 | h = self.heads | ||
| 45 | |||
| 46 | assert uc_context.size(0) == context_k[0].size(0) == context_v[0].size(0) | ||
| 47 | true_bs = uc_context.size(0) * h | ||
| 48 | |||
| 49 | k_uc, v_uc = self.get_kv(uc_context) | ||
| 50 | k_c = [self.to_k(c_k) for c_k in context_k] | ||
| 51 | v_c = [self.to_v(c_v) for c_v in context_v] | ||
| 52 | |||
| 53 | q = rearrange(q, 'b n (h d) -> (b h) n d', h=h) | ||
| 54 | |||
| 55 | k_uc = rearrange(k_uc, 'b n (h d) -> (b h) n d', h=h) | ||
| 56 | v_uc = rearrange(v_uc, 'b n (h d) -> (b h) n d', h=h) | ||
| 57 | |||
| 58 | k_c = [rearrange(k, 'b n (h d) -> (b h) n d', h=h) for k in k_c] # NOTE: modification point | ||
| 59 | v_c = [rearrange(v, 'b n (h d) -> (b h) n d', h=h) for v in v_c] | ||
| 60 | |||
| 61 | # get composition | ||
| 62 | sim_uc = einsum('b i d, b j d -> b i j', q[:true_bs], k_uc) * self.scale | ||
| 63 | sim_c = [einsum('b i d, b j d -> b i j', q[true_bs:], k) * self.scale for k in k_c] | ||
| 64 | |||
| 65 | attn_uc = sim_uc.softmax(dim=-1) | ||
| 66 | attn_c = [sim.softmax(dim=-1) for sim in sim_c] | ||
| 67 | |||
| 68 | # get uc output | ||
| 69 | out_uc = einsum('b i j, b j d -> b i d', attn_uc, v_uc) | ||
| 70 | |||
| 71 | # get c output | ||
| 72 | if len(v_c) == 1: | ||
| 73 | out_c_collect = [] | ||
| 74 | for attn in attn_c: | ||
| 75 | for v in v_c: | ||
| 76 | out_c_collect.append(einsum('b i j, b j d -> b i d', attn, v)) | ||
| 77 | out_c = sum(out_c_collect) / len(out_c_collect) | ||
| 78 | else: | ||
| 79 | out_c = sum([einsum('b i j, b j d -> b i d', attn, v) for attn, v in zip(attn_c, v_c)]) / len(v_c) | ||
| 80 | |||
| 81 | out = torch.cat([out_uc, out_c], dim=0) | ||
| 82 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) | ||
| 83 | |||
| 84 | return out | ||
| 85 | |||
| 86 | def normal_qkv(self, q, context, mask): | ||
| 87 | h = self.heads | ||
| 88 | k = self.to_k(context) | ||
| 89 | v = self.to_v(context) | ||
| 90 | |||
| 91 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) | ||
| 92 | |||
| 93 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale | ||
| 94 | |||
| 95 | if exists(mask): | ||
| 96 | mask = rearrange(mask, 'b ... -> b (...)') | ||
| 97 | max_neg_value = -torch.finfo(sim.dtype).max | ||
| 98 | mask = repeat(mask, 'b j -> (b h) () j', h=h) | ||
| 99 | sim.masked_fill_(~mask, max_neg_value) | ||
| 100 | |||
| 101 | # attention, what we cannot get enough of | ||
| 102 | attn = sim.softmax(dim=-1) | ||
| 103 | |||
| 104 | out = einsum('b i j, b j d -> b i d', attn, v) | ||
| 105 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) | ||
| 106 | |||
| 107 | return out | ||
| 108 | |||
| 109 | def heterogeous_qkv(self, q, uc_context, context_k, context_v, mask): | ||
| 110 | h = self.heads | ||
| 111 | k = self.to_k(torch.cat([uc_context, context_k], dim=0)) | ||
| 112 | v = self.to_v(torch.cat([uc_context, context_v], dim=0)) | ||
| 113 | |||
| 114 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) | ||
| 115 | |||
| 116 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale | ||
| 117 | |||
| 118 | if exists(mask): | ||
| 119 | mask = rearrange(mask, 'b ... -> b (...)') | ||
| 120 | max_neg_value = -torch.finfo(sim.dtype).max | ||
| 121 | mask = repeat(mask, 'b j -> (b h) () j', h=h) | ||
| 122 | sim.masked_fill_(~mask, max_neg_value) | ||
| 123 | |||
| 124 | # attention, what we cannot get enough of | ||
| 125 | attn = sim.softmax(dim=-1) | ||
| 126 | |||
| 127 | out = einsum('b i j, b j d -> b i d', attn, v) | ||
| 128 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) | ||
| 129 | return out | ||
| 130 | |||
| 131 | def get_kv(self, context): | ||
| 132 | return self.to_k(context), self.to_v(context) | ||
