diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/attention/control.py | 168 | ||||
-rw-r--r-- | models/attention/hook.py | 62 | ||||
-rw-r--r-- | models/attention/structured.py | 132 |
3 files changed, 362 insertions, 0 deletions
diff --git a/models/attention/control.py b/models/attention/control.py new file mode 100644 index 0000000..248bd9f --- /dev/null +++ b/models/attention/control.py | |||
@@ -0,0 +1,168 @@ | |||
1 | import torch | ||
2 | import abc | ||
3 | |||
4 | |||
5 | class AttentionControl(abc.ABC): | ||
6 | def step_callback(self, x_t): | ||
7 | return x_t | ||
8 | |||
9 | def between_steps(self): | ||
10 | return | ||
11 | |||
12 | @property | ||
13 | def num_uncond_att_layers(self): | ||
14 | return self.num_att_layers if LOW_RESOURCE else 0 | ||
15 | |||
16 | @abc.abstractmethod | ||
17 | def forward(self, attn, is_cross: bool, place_in_unet: str): | ||
18 | raise NotImplementedError | ||
19 | |||
20 | def __call__(self, attn, is_cross: bool, place_in_unet: str): | ||
21 | if self.cur_att_layer >= self.num_uncond_att_layers: | ||
22 | if LOW_RESOURCE: | ||
23 | attn = self.forward(attn, is_cross, place_in_unet) | ||
24 | else: | ||
25 | h = attn.shape[0] | ||
26 | attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) | ||
27 | self.cur_att_layer += 1 | ||
28 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: | ||
29 | self.cur_att_layer = 0 | ||
30 | self.cur_step += 1 | ||
31 | self.between_steps() | ||
32 | return attn | ||
33 | |||
34 | def reset(self): | ||
35 | self.cur_step = 0 | ||
36 | self.cur_att_layer = 0 | ||
37 | |||
38 | def __init__(self): | ||
39 | self.cur_step = 0 | ||
40 | self.num_att_layers = -1 | ||
41 | self.cur_att_layer = 0 | ||
42 | |||
43 | |||
44 | class EmptyControl(AttentionControl): | ||
45 | def forward(self, attn, is_cross: bool, place_in_unet: str): | ||
46 | return attn | ||
47 | |||
48 | |||
49 | class AttentionStore(AttentionControl): | ||
50 | @staticmethod | ||
51 | def get_empty_store(): | ||
52 | return {"down_cross": [], "mid_cross": [], "up_cross": [], | ||
53 | "down_self": [], "mid_self": [], "up_self": []} | ||
54 | |||
55 | def forward(self, attn, is_cross: bool, place_in_unet: str): | ||
56 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" | ||
57 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead | ||
58 | self.step_store[key].append(attn) | ||
59 | return attn | ||
60 | |||
61 | def between_steps(self): | ||
62 | if len(self.attention_store) == 0: | ||
63 | self.attention_store = self.step_store | ||
64 | else: | ||
65 | for key in self.attention_store: | ||
66 | for i in range(len(self.attention_store[key])): | ||
67 | self.attention_store[key][i] += self.step_store[key][i] | ||
68 | self.step_store = self.get_empty_store() | ||
69 | |||
70 | def get_average_attention(self): | ||
71 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] | ||
72 | for key in self.attention_store} | ||
73 | return average_attention | ||
74 | |||
75 | def reset(self): | ||
76 | super(AttentionStore, self).reset() | ||
77 | self.step_store = self.get_empty_store() | ||
78 | self.attention_store = {} | ||
79 | |||
80 | def __init__(self): | ||
81 | super(AttentionStore, self).__init__() | ||
82 | self.step_store = self.get_empty_store() | ||
83 | self.attention_store = {} | ||
84 | |||
85 | |||
86 | class AttentionControlEdit(AttentionStore, abc.ABC): | ||
87 | def step_callback(self, x_t): | ||
88 | if self.local_blend is not None: | ||
89 | x_t = self.local_blend(x_t, self.attention_store) | ||
90 | return x_t | ||
91 | |||
92 | def replace_self_attention(self, attn_base, att_replace): | ||
93 | if att_replace.shape[2] <= 16 ** 2: | ||
94 | return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) | ||
95 | else: | ||
96 | return att_replace | ||
97 | |||
98 | @abc.abstractmethod | ||
99 | def replace_cross_attention(self, attn_base, att_replace): | ||
100 | raise NotImplementedError | ||
101 | |||
102 | def forward(self, attn, is_cross: bool, place_in_unet: str): | ||
103 | super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) | ||
104 | if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]): | ||
105 | h = attn.shape[0] // (self.batch_size) | ||
106 | attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) | ||
107 | attn_base, attn_repalce = attn[0], attn[1:] | ||
108 | if is_cross: | ||
109 | alpha_words = self.cross_replace_alpha[self.cur_step] | ||
110 | attn_repalce_new = self.replace_cross_attention( | ||
111 | attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce | ||
112 | attn[1:] = attn_repalce_new | ||
113 | else: | ||
114 | attn[1:] = self.replace_self_attention(attn_base, attn_repalce) | ||
115 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) | ||
116 | return attn | ||
117 | |||
118 | def __init__(self, prompts, num_steps: int, | ||
119 | cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], | ||
120 | self_replace_steps: Union[float, Tuple[float, float]], | ||
121 | local_blend: Optional[LocalBlend]): | ||
122 | super(AttentionControlEdit, self).__init__() | ||
123 | self.batch_size = len(prompts) | ||
124 | self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha( | ||
125 | prompts, num_steps, cross_replace_steps, tokenizer).to(device) | ||
126 | if type(self_replace_steps) is float: | ||
127 | self_replace_steps = 0, self_replace_steps | ||
128 | self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1]) | ||
129 | self.local_blend = local_blend | ||
130 | |||
131 | |||
132 | class AttentionReplace(AttentionControlEdit): | ||
133 | def replace_cross_attention(self, attn_base, att_replace): | ||
134 | return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper) | ||
135 | |||
136 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, | ||
137 | local_blend: Optional[LocalBlend] = None): | ||
138 | super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) | ||
139 | self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device) | ||
140 | |||
141 | |||
142 | class AttentionRefine(AttentionControlEdit): | ||
143 | def replace_cross_attention(self, attn_base, att_replace): | ||
144 | attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3) | ||
145 | attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) | ||
146 | return attn_replace | ||
147 | |||
148 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, | ||
149 | local_blend: Optional[LocalBlend] = None): | ||
150 | super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) | ||
151 | self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer) | ||
152 | self.mapper, alphas = self.mapper.to(device), alphas.to(device) | ||
153 | self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) | ||
154 | |||
155 | |||
156 | class AttentionReweight(AttentionControlEdit): | ||
157 | def replace_cross_attention(self, attn_base, att_replace): | ||
158 | if self.prev_controller is not None: | ||
159 | attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace) | ||
160 | attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] | ||
161 | return attn_replace | ||
162 | |||
163 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer, | ||
164 | local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None): | ||
165 | super(AttentionReweight, self).__init__(prompts, num_steps, | ||
166 | cross_replace_steps, self_replace_steps, local_blend) | ||
167 | self.equalizer = equalizer.to(device) | ||
168 | self.prev_controller = controller | ||
diff --git a/models/attention/hook.py b/models/attention/hook.py new file mode 100644 index 0000000..903de02 --- /dev/null +++ b/models/attention/hook.py | |||
@@ -0,0 +1,62 @@ | |||
1 | import torch | ||
2 | |||
3 | |||
4 | try: | ||
5 | import xformers.ops | ||
6 | xformers._is_functorch_available = True | ||
7 | MEM_EFFICIENT_ATTN = True | ||
8 | except ImportError: | ||
9 | print("[!] Not using xformers memory efficient attention.") | ||
10 | MEM_EFFICIENT_ATTN = False | ||
11 | |||
12 | |||
13 | def register_attention_control(model, controller): | ||
14 | def ca_forward(self, place_in_unet): | ||
15 | def forward(x, context=None, mask=None): | ||
16 | batch_size, sequence_length, dim = x.shape | ||
17 | h = self.heads | ||
18 | q = self.to_q(x) | ||
19 | is_cross = context is not None | ||
20 | context = context if is_cross else x | ||
21 | k = self.to_k(context) | ||
22 | v = self.to_v(context) | ||
23 | q = self.reshape_heads_to_batch_dim(q) | ||
24 | k = self.reshape_heads_to_batch_dim(k) | ||
25 | v = self.reshape_heads_to_batch_dim(v) | ||
26 | |||
27 | sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale | ||
28 | |||
29 | if mask is not None: | ||
30 | mask = mask.reshape(batch_size, -1) | ||
31 | max_neg_value = -torch.finfo(sim.dtype).max | ||
32 | mask = mask[:, None, :].repeat(h, 1, 1) | ||
33 | sim.masked_fill_(~mask, max_neg_value) | ||
34 | |||
35 | # attention, what we cannot get enough of | ||
36 | attn = sim.softmax(dim=-1) | ||
37 | attn = controller(attn, is_cross, place_in_unet) | ||
38 | out = torch.einsum("b i j, b j d -> b i d", attn, v) | ||
39 | out = self.reshape_batch_dim_to_heads(out) | ||
40 | return self.to_out(out) | ||
41 | |||
42 | return forward | ||
43 | |||
44 | def register_recr(net_, count, place_in_unet): | ||
45 | if net_.__class__.__name__ == 'CrossAttention': | ||
46 | net_.forward = ca_forward(net_, place_in_unet) | ||
47 | return count + 1 | ||
48 | elif hasattr(net_, 'children'): | ||
49 | for net__ in net_.children(): | ||
50 | count = register_recr(net__, count, place_in_unet) | ||
51 | return count | ||
52 | |||
53 | cross_att_count = 0 | ||
54 | sub_nets = model.unet.named_children() | ||
55 | for net in sub_nets: | ||
56 | if "down" in net[0]: | ||
57 | cross_att_count += register_recr(net[1], 0, "down") | ||
58 | elif "up" in net[0]: | ||
59 | cross_att_count += register_recr(net[1], 0, "up") | ||
60 | elif "mid" in net[0]: | ||
61 | cross_att_count += register_recr(net[1], 0, "mid") | ||
62 | controller.num_att_layers = cross_att_count | ||
diff --git a/models/attention/structured.py b/models/attention/structured.py new file mode 100644 index 0000000..24d889f --- /dev/null +++ b/models/attention/structured.py | |||
@@ -0,0 +1,132 @@ | |||
1 | import torch | ||
2 | |||
3 | from .control import AttentionControl | ||
4 | |||
5 | |||
6 | class StructuredAttentionControl(AttentionControl): | ||
7 | def forward(self, attn, is_cross: bool, place_in_unet: str): | ||
8 | return attn | ||
9 | |||
10 | def forward(self, x, context=None, mask=None): | ||
11 | h = self.heads | ||
12 | |||
13 | q = self.to_q(x) | ||
14 | |||
15 | if isinstance(context, list): | ||
16 | if self.struct_attn: | ||
17 | out = self.struct_qkv(q, context, mask) | ||
18 | else: | ||
19 | context = torch.cat([context[0], context[1]['k'][0]], dim=0) # use key tensor for context | ||
20 | out = self.normal_qkv(q, context, mask) | ||
21 | else: | ||
22 | context = default(context, x) | ||
23 | out = self.normal_qkv(q, context, mask) | ||
24 | |||
25 | return self.to_out(out) | ||
26 | |||
27 | def struct_qkv(self, q, context, mask): | ||
28 | """ | ||
29 | context: list of [uc, list of conditional context] | ||
30 | """ | ||
31 | uc_context = context[0] | ||
32 | context_k, context_v = context[1]['k'], context[1]['v'] | ||
33 | |||
34 | if isinstance(context_k, list) and isinstance(context_v, list): | ||
35 | out = self.multi_qkv(q, uc_context, context_k, context_v, mask) | ||
36 | elif isinstance(context_k, torch.Tensor) and isinstance(context_v, torch.Tensor): | ||
37 | out = self.heterogeous_qkv(q, uc_context, context_k, context_v, mask) | ||
38 | else: | ||
39 | raise NotImplementedError | ||
40 | |||
41 | return out | ||
42 | |||
43 | def multi_qkv(self, q, uc_context, context_k, context_v, mask): | ||
44 | h = self.heads | ||
45 | |||
46 | assert uc_context.size(0) == context_k[0].size(0) == context_v[0].size(0) | ||
47 | true_bs = uc_context.size(0) * h | ||
48 | |||
49 | k_uc, v_uc = self.get_kv(uc_context) | ||
50 | k_c = [self.to_k(c_k) for c_k in context_k] | ||
51 | v_c = [self.to_v(c_v) for c_v in context_v] | ||
52 | |||
53 | q = rearrange(q, 'b n (h d) -> (b h) n d', h=h) | ||
54 | |||
55 | k_uc = rearrange(k_uc, 'b n (h d) -> (b h) n d', h=h) | ||
56 | v_uc = rearrange(v_uc, 'b n (h d) -> (b h) n d', h=h) | ||
57 | |||
58 | k_c = [rearrange(k, 'b n (h d) -> (b h) n d', h=h) for k in k_c] # NOTE: modification point | ||
59 | v_c = [rearrange(v, 'b n (h d) -> (b h) n d', h=h) for v in v_c] | ||
60 | |||
61 | # get composition | ||
62 | sim_uc = einsum('b i d, b j d -> b i j', q[:true_bs], k_uc) * self.scale | ||
63 | sim_c = [einsum('b i d, b j d -> b i j', q[true_bs:], k) * self.scale for k in k_c] | ||
64 | |||
65 | attn_uc = sim_uc.softmax(dim=-1) | ||
66 | attn_c = [sim.softmax(dim=-1) for sim in sim_c] | ||
67 | |||
68 | # get uc output | ||
69 | out_uc = einsum('b i j, b j d -> b i d', attn_uc, v_uc) | ||
70 | |||
71 | # get c output | ||
72 | if len(v_c) == 1: | ||
73 | out_c_collect = [] | ||
74 | for attn in attn_c: | ||
75 | for v in v_c: | ||
76 | out_c_collect.append(einsum('b i j, b j d -> b i d', attn, v)) | ||
77 | out_c = sum(out_c_collect) / len(out_c_collect) | ||
78 | else: | ||
79 | out_c = sum([einsum('b i j, b j d -> b i d', attn, v) for attn, v in zip(attn_c, v_c)]) / len(v_c) | ||
80 | |||
81 | out = torch.cat([out_uc, out_c], dim=0) | ||
82 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) | ||
83 | |||
84 | return out | ||
85 | |||
86 | def normal_qkv(self, q, context, mask): | ||
87 | h = self.heads | ||
88 | k = self.to_k(context) | ||
89 | v = self.to_v(context) | ||
90 | |||
91 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) | ||
92 | |||
93 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale | ||
94 | |||
95 | if exists(mask): | ||
96 | mask = rearrange(mask, 'b ... -> b (...)') | ||
97 | max_neg_value = -torch.finfo(sim.dtype).max | ||
98 | mask = repeat(mask, 'b j -> (b h) () j', h=h) | ||
99 | sim.masked_fill_(~mask, max_neg_value) | ||
100 | |||
101 | # attention, what we cannot get enough of | ||
102 | attn = sim.softmax(dim=-1) | ||
103 | |||
104 | out = einsum('b i j, b j d -> b i d', attn, v) | ||
105 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) | ||
106 | |||
107 | return out | ||
108 | |||
109 | def heterogeous_qkv(self, q, uc_context, context_k, context_v, mask): | ||
110 | h = self.heads | ||
111 | k = self.to_k(torch.cat([uc_context, context_k], dim=0)) | ||
112 | v = self.to_v(torch.cat([uc_context, context_v], dim=0)) | ||
113 | |||
114 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) | ||
115 | |||
116 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale | ||
117 | |||
118 | if exists(mask): | ||
119 | mask = rearrange(mask, 'b ... -> b (...)') | ||
120 | max_neg_value = -torch.finfo(sim.dtype).max | ||
121 | mask = repeat(mask, 'b j -> (b h) () j', h=h) | ||
122 | sim.masked_fill_(~mask, max_neg_value) | ||
123 | |||
124 | # attention, what we cannot get enough of | ||
125 | attn = sim.softmax(dim=-1) | ||
126 | |||
127 | out = einsum('b i j, b j d -> b i d', attn, v) | ||
128 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) | ||
129 | return out | ||
130 | |||
131 | def get_kv(self, context): | ||
132 | return self.to_k(context), self.to_v(context) | ||