From 186a69104530610f8c2b924f79a04f941e5238c8 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 22 Jun 2023 07:33:29 +0200
Subject: Remove convnext

---
 models/attention/hook.py | 63 ------------------------------------------------
 1 file changed, 63 deletions(-)
 delete mode 100644 models/attention/hook.py

diff --git a/models/attention/hook.py b/models/attention/hook.py
deleted file mode 100644
index 6b5fb68..0000000
--- a/models/attention/hook.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import torch
-
-
-try:
-    import xformers.ops
-
-    xformers._is_functorch_available = True
-    MEM_EFFICIENT_ATTN = True
-except ImportError:
-    print("[!] Not using xformers memory efficient attention.")
-    MEM_EFFICIENT_ATTN = False
-
-
-def register_attention_control(model, controller):
-    def ca_forward(self, place_in_unet):
-        def forward(x, context=None, mask=None):
-            batch_size, sequence_length, dim = x.shape
-            h = self.heads
-            q = self.to_q(x)
-            is_cross = context is not None
-            context = context if is_cross else x
-            k = self.to_k(context)
-            v = self.to_v(context)
-            q = self.reshape_heads_to_batch_dim(q)
-            k = self.reshape_heads_to_batch_dim(k)
-            v = self.reshape_heads_to_batch_dim(v)
-
-            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
-
-            if mask is not None:
-                mask = mask.reshape(batch_size, -1)
-                max_neg_value = -torch.finfo(sim.dtype).max
-                mask = mask[:, None, :].repeat(h, 1, 1)
-                sim.masked_fill_(~mask, max_neg_value)
-
-            # attention, what we cannot get enough of
-            attn = sim.softmax(dim=-1)
-            attn = controller(attn, is_cross, place_in_unet)
-            out = torch.einsum("b i j, b j d -> b i d", attn, v)
-            out = self.reshape_batch_dim_to_heads(out)
-            return self.to_out(out)
-
-        return forward
-
-    def register_recr(net_, count, place_in_unet):
-        if net_.__class__.__name__ == "CrossAttention":
-            net_.forward = ca_forward(net_, place_in_unet)
-            return count + 1
-        elif hasattr(net_, "children"):
-            for net__ in net_.children():
-                count = register_recr(net__, count, place_in_unet)
-        return count
-
-    cross_att_count = 0
-    sub_nets = model.unet.named_children()
-    for net in sub_nets:
-        if "down" in net[0]:
-            cross_att_count += register_recr(net[1], 0, "down")
-        elif "up" in net[0]:
-            cross_att_count += register_recr(net[1], 0, "up")
-        elif "mid" in net[0]:
-            cross_att_count += register_recr(net[1], 0, "mid")
-    controller.num_att_layers = cross_att_count
--
cgit v1.2.3-54-g00ecf
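
Note: the removed register_attention_control expects "controller" to be a
callable that receives every attention map as controller(attn, is_cross,
place_in_unet) and returns the (possibly modified) map; the function also
writes the number of patched layers back to controller.num_att_layers.
Below is a minimal sketch of a compatible pass-through controller that only
records cross-attention maps. It assumes a diffusers-style pipeline object
("pipe") that exposes .unet and whose attention modules are still named
"CrossAttention" (newer diffusers releases renamed the class to "Attention",
which would break the class-name check above); the AttentionStore name is
illustrative, not from this repository.

    import torch


    class AttentionStore:
        """Pass-through controller that records cross-attention maps."""

        def __init__(self):
            self.num_att_layers = 0  # overwritten by register_attention_control
            self.maps = {"down": [], "mid": [], "up": []}

        def __call__(self, attn, is_cross, place_in_unet):
            # The patched forward uses whatever this returns as the attention
            # weights, so returning attn unchanged keeps the model's output
            # identical while capturing the maps for later inspection.
            if is_cross:
                self.maps[place_in_unet].append(attn.detach().cpu())
            return attn


    controller = AttentionStore()
    register_attention_control(pipe, controller)  # pipe: hypothetical pipeline with a .unet
    print(controller.num_att_layers, "attention layers patched")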