# models/attention/hook.py
# Recovered from commit 847ec3b6c43c89ef3649715f86ecfed370b6e442
# (Volpeon, 2022-10-24, "Update").
"""Route the attention maps of a UNet's ``CrossAttention`` layers through a controller."""

import torch


try:
    import xformers.ops
    # NOTE(review): private xformers flag forced on right after import;
    # presumably works around a functorch availability check -- confirm.
    xformers._is_functorch_available = True
    MEM_EFFICIENT_ATTN = True
except ImportError:
    print("[!] Not using xformers memory efficient attention.")
    MEM_EFFICIENT_ATTN = False


def register_attention_control(model, controller):
    """Replace ``forward`` on every ``CrossAttention`` layer under ``model.unet``.

    Direct children of ``model.unet`` whose name contains ``"down"``, ``"up"``
    or ``"mid"`` (checked in that order) are walked recursively. Each module
    whose class is named ``CrossAttention`` gets an instance-level ``forward``
    that computes softmax attention explicitly and passes the attention
    probabilities through ``controller`` before they are applied to the values.

    Args:
        model: Object exposing ``unet``, which must provide
            ``named_children()``.
        controller: Callable invoked as
            ``controller(attn, is_cross, place_in_unet)``; its return value is
            used as the attention probabilities. ``is_cross`` is ``True`` when
            a ``context`` was passed to the layer, and ``place_in_unet`` is
            ``"down"``, ``"up"`` or ``"mid"``. The number of hooked layers is
            stored on it as ``controller.num_att_layers``.

    Returns:
        None. ``model`` and ``controller`` are modified in place.
    """

    def ca_forward(self, place_in_unet):
        """Build the replacement ``forward`` for one attention layer."""

        def forward(x, context=None, mask=None):
            # Unpacking also enforces that x is a 3-D tensor.
            batch_size, _, _ = x.shape
            h = self.heads

            q = self.to_q(x)
            is_cross = context is not None
            # Self-attention: keys and values come from the input itself.
            context = context if is_cross else x
            k = self.to_k(context)
            v = self.to_v(context)

            q = self.reshape_heads_to_batch_dim(q)
            k = self.reshape_heads_to_batch_dim(k)
            v = self.reshape_heads_to_batch_dim(v)

            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

            if mask is not None:
                mask = mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                # Expand (batch, keys) to (batch * heads, 1, keys). The rows
                # must be batch-major (every head of sample 0, then sample 1,
                # ...) so each head is masked with its own sample's mask.
                # Assumes reshape_heads_to_batch_dim uses that layout, as
                # diffusers' CrossAttention does -- confirm if the attention
                # class changes. Tensor.repeat would tile the batch instead
                # and mix masks between samples when batch_size > 1.
                mask = mask[:, None, :].repeat_interleave(h, dim=0)
                sim.masked_fill_(~mask, max_neg_value)

            # attention, what we cannot get enough of
            attn = sim.softmax(dim=-1)
            # The controller may inspect or replace the attention map.
            attn = controller(attn, is_cross, place_in_unet)
            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out = self.reshape_batch_dim_to_heads(out)
            return self.to_out(out)

        return forward

    def register_recr(net_, count, place_in_unet):
        """Hook every ``CrossAttention`` below ``net_``; return the updated count."""
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count

    cross_att_count = 0
    for name, sub_net in model.unet.named_children():
        # First matching tag wins, same precedence as down / up / mid.
        for place in ("down", "up", "mid"):
            if place in name:
                cross_att_count += register_recr(sub_net, 0, place)
                break
    controller.num_att_layers = cross_att_count