From 186a69104530610f8c2b924f79a04f941e5238c8 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 22 Jun 2023 07:33:29 +0200
Subject: Remove convnext

---
 models/attention/structured.py | 145 -----------------------------------------
 1 file changed, 145 deletions(-)
 delete mode 100644 models/attention/structured.py

diff --git a/models/attention/structured.py b/models/attention/structured.py
deleted file mode 100644
index 5bbbc06..0000000
--- a/models/attention/structured.py
+++ /dev/null
@@ -1,145 +0,0 @@
-import torch
-
-from .control import AttentionControl
-
-
-class StructuredAttentionControl(AttentionControl):
-    def forward(self, attn, is_cross: bool, place_in_unet: str):
-        return attn
-
-    def forward(self, x, context=None, mask=None):
-        h = self.heads
-
-        q = self.to_q(x)
-
-        if isinstance(context, list):
-            if self.struct_attn:
-                out = self.struct_qkv(q, context, mask)
-            else:
-                context = torch.cat(
-                    [context[0], context[1]["k"][0]], dim=0
-                )  # use key tensor for context
-                out = self.normal_qkv(q, context, mask)
-        else:
-            context = default(context, x)
-            out = self.normal_qkv(q, context, mask)
-
-        return self.to_out(out)
-
-    def struct_qkv(self, q, context, mask):
-        """
-        context: list of [uc, list of conditional context]
-        """
-        uc_context = context[0]
-        context_k, context_v = context[1]["k"], context[1]["v"]
-
-        if isinstance(context_k, list) and isinstance(context_v, list):
-            out = self.multi_qkv(q, uc_context, context_k, context_v, mask)
-        elif isinstance(context_k, torch.Tensor) and isinstance(
-            context_v, torch.Tensor
-        ):
-            out = self.heterogeous_qkv(q, uc_context, context_k, context_v, mask)
-        else:
-            raise NotImplementedError
-
-        return out
-
-    def multi_qkv(self, q, uc_context, context_k, context_v, mask):
-        h = self.heads
-
-        assert uc_context.size(0) == context_k[0].size(0) == context_v[0].size(0)
-        true_bs = uc_context.size(0) * h
-
-        k_uc, v_uc = self.get_kv(uc_context)
-        k_c = [self.to_k(c_k) for c_k in context_k]
-        v_c = [self.to_v(c_v) for c_v in context_v]
-
-        q = rearrange(q, "b n (h d) -> (b h) n d", h=h)
-
-        k_uc = rearrange(k_uc, "b n (h d) -> (b h) n d", h=h)
-        v_uc = rearrange(v_uc, "b n (h d) -> (b h) n d", h=h)
-
-        k_c = [
-            rearrange(k, "b n (h d) -> (b h) n d", h=h) for k in k_c
-        ]  # NOTE: modification point
-        v_c = [rearrange(v, "b n (h d) -> (b h) n d", h=h) for v in v_c]
-
-        # get composition
-        sim_uc = einsum("b i d, b j d -> b i j", q[:true_bs], k_uc) * self.scale
-        sim_c = [
-            einsum("b i d, b j d -> b i j", q[true_bs:], k) * self.scale for k in k_c
-        ]
-
-        attn_uc = sim_uc.softmax(dim=-1)
-        attn_c = [sim.softmax(dim=-1) for sim in sim_c]
-
-        # get uc output
-        out_uc = einsum("b i j, b j d -> b i d", attn_uc, v_uc)
-
-        # get c output
-        if len(v_c) == 1:
-            out_c_collect = []
-            for attn in attn_c:
-                for v in v_c:
-                    out_c_collect.append(einsum("b i j, b j d -> b i d", attn, v))
-            out_c = sum(out_c_collect) / len(out_c_collect)
-        else:
-            out_c = sum(
-                [
-                    einsum("b i j, b j d -> b i d", attn, v)
-                    for attn, v in zip(attn_c, v_c)
-                ]
-            ) / len(v_c)
-
-        out = torch.cat([out_uc, out_c], dim=0)
-        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
-
-        return out
-
-    def normal_qkv(self, q, context, mask):
-        h = self.heads
-        k = self.to_k(context)
-        v = self.to_v(context)
-
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
-
-        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
-
-        if exists(mask):
-            mask = rearrange(mask, "b ... -> b (...)")
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, "b j -> (b h) () j", h=h)
-            sim.masked_fill_(~mask, max_neg_value)
-
-        # attention, what we cannot get enough of
-        attn = sim.softmax(dim=-1)
-
-        out = einsum("b i j, b j d -> b i d", attn, v)
-        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
-
-        return out
-
-    def heterogeous_qkv(self, q, uc_context, context_k, context_v, mask):
-        h = self.heads
-        k = self.to_k(torch.cat([uc_context, context_k], dim=0))
-        v = self.to_v(torch.cat([uc_context, context_v], dim=0))
-
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
-
-        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
-
-        if exists(mask):
-            mask = rearrange(mask, "b ... -> b (...)")
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, "b j -> (b h) () j", h=h)
-            sim.masked_fill_(~mask, max_neg_value)
-
-        # attention, what we cannot get enough of
-        attn = sim.softmax(dim=-1)
-
-        out = einsum("b i j, b j d -> b i d", attn, v)
-        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
-        return out
-
-    def get_kv(self, context):
-        return self.to_k(context), self.to_v(context)
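
For reference, the removed module uses rearrange/repeat (einops), einsum, and the exists/default helpers without importing them, so it relied on the surrounding cross-attention code to provide those names, and the first forward definition is shadowed by the second. The sketch below is a minimal, self-contained version of the composition that multi_qkv performs: one attention map per conditional context, with the conditional outputs averaged and then concatenated after the unconditional output. The function and parameter names here (composed_attention, to_k, to_v, heads, scale) are illustrative assumptions, not part of the removed file, and the sketch assumes equal-length key/value context lists (the original also handles a single shared value tensor).

# Hedged sketch (assumption, not the removed file): a standalone version of the
# composition performed by StructuredAttentionControl.multi_qkv.
# Requires only torch and einops.
import torch
from torch import einsum
from einops import rearrange


def exists(val):
    # Presumed definition of the helper referenced by the removed module.
    return val is not None


def default(val, d):
    # Presumed definition: fall back to d when val is None.
    return val if exists(val) else d


def composed_attention(q, uc_context, context_ks, context_vs, to_k, to_v, heads, scale):
    # q: (2b, n, heads*d) queries, unconditional half first, conditional half second.
    # uc_context: (b, m, dim); context_ks/context_vs: lists of (b, m, dim) tensors.
    true_bs = uc_context.size(0) * heads

    k_uc = rearrange(to_k(uc_context), "b n (h d) -> (b h) n d", h=heads)
    v_uc = rearrange(to_v(uc_context), "b n (h d) -> (b h) n d", h=heads)
    k_c = [rearrange(to_k(c), "b n (h d) -> (b h) n d", h=heads) for c in context_ks]
    v_c = [rearrange(to_v(c), "b n (h d) -> (b h) n d", h=heads) for c in context_vs]

    q = rearrange(q, "b n (h d) -> (b h) n d", h=heads)

    # Unconditional branch: ordinary attention against the unconditional context.
    attn_uc = (einsum("b i d, b j d -> b i j", q[:true_bs], k_uc) * scale).softmax(dim=-1)
    out_uc = einsum("b i j, b j d -> b i d", attn_uc, v_uc)

    # Conditional branch: one attention map per context, outputs averaged.
    outs_c = []
    for k, v in zip(k_c, v_c):
        attn = (einsum("b i d, b j d -> b i j", q[true_bs:], k) * scale).softmax(dim=-1)
        outs_c.append(einsum("b i j, b j d -> b i d", attn, v))
    out_c = sum(outs_c) / len(outs_c)

    # Re-assemble the [uncond, cond] batch and fold the heads back into the feature dim.
    out = torch.cat([out_uc, out_c], dim=0)
    return rearrange(out, "(b h) n d -> b n (h d)", h=heads)


# Example usage with dummy shapes (illustrative only):
# dim, heads, d_head = 320, 8, 40
# to_k = to_v = torch.nn.Linear(768, dim, bias=False)
# q = torch.randn(2, 64, dim)                            # [uncond, cond] queries
# uc = torch.randn(1, 77, 768)                           # unconditional context
# ks = vs = [torch.randn(1, 77, 768) for _ in range(2)]  # per-concept contexts
# out = composed_attention(q, uc, ks, vs, to_k, to_v, heads, d_head ** -0.5)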