From 186a69104530610f8c2b924f79a04f941e5238c8 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 22 Jun 2023 07:33:29 +0200 Subject: Remove convnext --- models/attention/control.py | 216 --------------------------------------- models/attention/hook.py | 63 ------------ models/attention/structured.py | 145 -------------------------- models/convnext/discriminator.py | 34 ------ 4 files changed, 458 deletions(-) delete mode 100644 models/attention/control.py delete mode 100644 models/attention/hook.py delete mode 100644 models/attention/structured.py delete mode 100644 models/convnext/discriminator.py (limited to 'models') diff --git a/models/attention/control.py b/models/attention/control.py deleted file mode 100644 index ec378c4..0000000 --- a/models/attention/control.py +++ /dev/null @@ -1,216 +0,0 @@ -import torch -import abc - - -class AttentionControl(abc.ABC): - def step_callback(self, x_t): - return x_t - - def between_steps(self): - return - - @property - def num_uncond_att_layers(self): - return self.num_att_layers if LOW_RESOURCE else 0 - - @abc.abstractmethod - def forward(self, attn, is_cross: bool, place_in_unet: str): - raise NotImplementedError - - def __call__(self, attn, is_cross: bool, place_in_unet: str): - if self.cur_att_layer >= self.num_uncond_att_layers: - if LOW_RESOURCE: - attn = self.forward(attn, is_cross, place_in_unet) - else: - h = attn.shape[0] - attn[h // 2 :] = self.forward(attn[h // 2 :], is_cross, place_in_unet) - self.cur_att_layer += 1 - if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: - self.cur_att_layer = 0 - self.cur_step += 1 - self.between_steps() - return attn - - def reset(self): - self.cur_step = 0 - self.cur_att_layer = 0 - - def __init__(self): - self.cur_step = 0 - self.num_att_layers = -1 - self.cur_att_layer = 0 - - -class EmptyControl(AttentionControl): - def forward(self, attn, is_cross: bool, place_in_unet: str): - return attn - - -class AttentionStore(AttentionControl): - @staticmethod - def get_empty_store(): - return { - "down_cross": [], - "mid_cross": [], - "up_cross": [], - "down_self": [], - "mid_self": [], - "up_self": [], - } - - def forward(self, attn, is_cross: bool, place_in_unet: str): - key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" - if attn.shape[1] <= 32**2: # avoid memory overhead - self.step_store[key].append(attn) - return attn - - def between_steps(self): - if len(self.attention_store) == 0: - self.attention_store = self.step_store - else: - for key in self.attention_store: - for i in range(len(self.attention_store[key])): - self.attention_store[key][i] += self.step_store[key][i] - self.step_store = self.get_empty_store() - - def get_average_attention(self): - average_attention = { - key: [item / self.cur_step for item in self.attention_store[key]] - for key in self.attention_store - } - return average_attention - - def reset(self): - super(AttentionStore, self).reset() - self.step_store = self.get_empty_store() - self.attention_store = {} - - def __init__(self): - super(AttentionStore, self).__init__() - self.step_store = self.get_empty_store() - self.attention_store = {} - - -class AttentionControlEdit(AttentionStore, abc.ABC): - def step_callback(self, x_t): - if self.local_blend is not None: - x_t = self.local_blend(x_t, self.attention_store) - return x_t - - def replace_self_attention(self, attn_base, att_replace): - if att_replace.shape[2] <= 16**2: - return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) - else: - return att_replace - - @abc.abstractmethod - def replace_cross_attention(self, attn_base, att_replace): - raise NotImplementedError - - def forward(self, attn, is_cross: bool, place_in_unet: str): - super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) - if is_cross or ( - self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1] - ): - h = attn.shape[0] // (self.batch_size) - attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) - attn_base, attn_repalce = attn[0], attn[1:] - if is_cross: - alpha_words = self.cross_replace_alpha[self.cur_step] - attn_repalce_new = ( - self.replace_cross_attention(attn_base, attn_repalce) * alpha_words - + (1 - alpha_words) * attn_repalce - ) - attn[1:] = attn_repalce_new - else: - attn[1:] = self.replace_self_attention(attn_base, attn_repalce) - attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) - return attn - - def __init__( - self, - prompts, - num_steps: int, - cross_replace_steps: Union[ - float, Tuple[float, float], Dict[str, Tuple[float, float]] - ], - self_replace_steps: Union[float, Tuple[float, float]], - local_blend: Optional[LocalBlend], - ): - super(AttentionControlEdit, self).__init__() - self.batch_size = len(prompts) - self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha( - prompts, num_steps, cross_replace_steps, tokenizer - ).to(device) - if type(self_replace_steps) is float: - self_replace_steps = 0, self_replace_steps - self.num_self_replace = int(num_steps * self_replace_steps[0]), int( - num_steps * self_replace_steps[1] - ) - self.local_blend = local_blend - - -class AttentionReplace(AttentionControlEdit): - def replace_cross_attention(self, attn_base, att_replace): - return torch.einsum("hpw,bwn->bhpn", attn_base, self.mapper) - - def __init__( - self, - prompts, - num_steps: int, - cross_replace_steps: float, - self_replace_steps: float, - local_blend: Optional[LocalBlend] = None, - ): - super(AttentionReplace, self).__init__( - prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend - ) - self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device) - - -class AttentionRefine(AttentionControlEdit): - def replace_cross_attention(self, attn_base, att_replace): - attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3) - attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) - return attn_replace - - def __init__( - self, - prompts, - num_steps: int, - cross_replace_steps: float, - self_replace_steps: float, - local_blend: Optional[LocalBlend] = None, - ): - super(AttentionRefine, self).__init__( - prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend - ) - self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer) - self.mapper, alphas = self.mapper.to(device), alphas.to(device) - self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) - - -class AttentionReweight(AttentionControlEdit): - def replace_cross_attention(self, attn_base, att_replace): - if self.prev_controller is not None: - attn_base = self.prev_controller.replace_cross_attention( - attn_base, att_replace - ) - attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] - return attn_replace - - def __init__( - self, - prompts, - num_steps: int, - cross_replace_steps: float, - self_replace_steps: float, - equalizer, - local_blend: Optional[LocalBlend] = None, - controller: Optional[AttentionControlEdit] = None, - ): - super(AttentionReweight, self).__init__( - prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend - ) - self.equalizer = equalizer.to(device) - self.prev_controller = controller diff --git a/models/attention/hook.py b/models/attention/hook.py deleted file mode 100644 index 6b5fb68..0000000 --- a/models/attention/hook.py +++ /dev/null @@ -1,63 +0,0 @@ -import torch - - -try: - import xformers.ops - - xformers._is_functorch_available = True - MEM_EFFICIENT_ATTN = True -except ImportError: - print("[!] Not using xformers memory efficient attention.") - MEM_EFFICIENT_ATTN = False - - -def register_attention_control(model, controller): - def ca_forward(self, place_in_unet): - def forward(x, context=None, mask=None): - batch_size, sequence_length, dim = x.shape - h = self.heads - q = self.to_q(x) - is_cross = context is not None - context = context if is_cross else x - k = self.to_k(context) - v = self.to_v(context) - q = self.reshape_heads_to_batch_dim(q) - k = self.reshape_heads_to_batch_dim(k) - v = self.reshape_heads_to_batch_dim(v) - - sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale - - if mask is not None: - mask = mask.reshape(batch_size, -1) - max_neg_value = -torch.finfo(sim.dtype).max - mask = mask[:, None, :].repeat(h, 1, 1) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) - attn = controller(attn, is_cross, place_in_unet) - out = torch.einsum("b i j, b j d -> b i d", attn, v) - out = self.reshape_batch_dim_to_heads(out) - return self.to_out(out) - - return forward - - def register_recr(net_, count, place_in_unet): - if net_.__class__.__name__ == "CrossAttention": - net_.forward = ca_forward(net_, place_in_unet) - return count + 1 - elif hasattr(net_, "children"): - for net__ in net_.children(): - count = register_recr(net__, count, place_in_unet) - return count - - cross_att_count = 0 - sub_nets = model.unet.named_children() - for net in sub_nets: - if "down" in net[0]: - cross_att_count += register_recr(net[1], 0, "down") - elif "up" in net[0]: - cross_att_count += register_recr(net[1], 0, "up") - elif "mid" in net[0]: - cross_att_count += register_recr(net[1], 0, "mid") - controller.num_att_layers = cross_att_count diff --git a/models/attention/structured.py b/models/attention/structured.py deleted file mode 100644 index 5bbbc06..0000000 --- a/models/attention/structured.py +++ /dev/null @@ -1,145 +0,0 @@ -import torch - -from .control import AttentionControl - - -class StructuredAttentionControl(AttentionControl): - def forward(self, attn, is_cross: bool, place_in_unet: str): - return attn - - def forward(self, x, context=None, mask=None): - h = self.heads - - q = self.to_q(x) - - if isinstance(context, list): - if self.struct_attn: - out = self.struct_qkv(q, context, mask) - else: - context = torch.cat( - [context[0], context[1]["k"][0]], dim=0 - ) # use key tensor for context - out = self.normal_qkv(q, context, mask) - else: - context = default(context, x) - out = self.normal_qkv(q, context, mask) - - return self.to_out(out) - - def struct_qkv(self, q, context, mask): - """ - context: list of [uc, list of conditional context] - """ - uc_context = context[0] - context_k, context_v = context[1]["k"], context[1]["v"] - - if isinstance(context_k, list) and isinstance(context_v, list): - out = self.multi_qkv(q, uc_context, context_k, context_v, mask) - elif isinstance(context_k, torch.Tensor) and isinstance( - context_v, torch.Tensor - ): - out = self.heterogeous_qkv(q, uc_context, context_k, context_v, mask) - else: - raise NotImplementedError - - return out - - def multi_qkv(self, q, uc_context, context_k, context_v, mask): - h = self.heads - - assert uc_context.size(0) == context_k[0].size(0) == context_v[0].size(0) - true_bs = uc_context.size(0) * h - - k_uc, v_uc = self.get_kv(uc_context) - k_c = [self.to_k(c_k) for c_k in context_k] - v_c = [self.to_v(c_v) for c_v in context_v] - - q = rearrange(q, "b n (h d) -> (b h) n d", h=h) - - k_uc = rearrange(k_uc, "b n (h d) -> (b h) n d", h=h) - v_uc = rearrange(v_uc, "b n (h d) -> (b h) n d", h=h) - - k_c = [ - rearrange(k, "b n (h d) -> (b h) n d", h=h) for k in k_c - ] # NOTE: modification point - v_c = [rearrange(v, "b n (h d) -> (b h) n d", h=h) for v in v_c] - - # get composition - sim_uc = einsum("b i d, b j d -> b i j", q[:true_bs], k_uc) * self.scale - sim_c = [ - einsum("b i d, b j d -> b i j", q[true_bs:], k) * self.scale for k in k_c - ] - - attn_uc = sim_uc.softmax(dim=-1) - attn_c = [sim.softmax(dim=-1) for sim in sim_c] - - # get uc output - out_uc = einsum("b i j, b j d -> b i d", attn_uc, v_uc) - - # get c output - if len(v_c) == 1: - out_c_collect = [] - for attn in attn_c: - for v in v_c: - out_c_collect.append(einsum("b i j, b j d -> b i d", attn, v)) - out_c = sum(out_c_collect) / len(out_c_collect) - else: - out_c = sum( - [ - einsum("b i j, b j d -> b i d", attn, v) - for attn, v in zip(attn_c, v_c) - ] - ) / len(v_c) - - out = torch.cat([out_uc, out_c], dim=0) - out = rearrange(out, "(b h) n d -> b n (h d)", h=h) - - return out - - def normal_qkv(self, q, context, mask): - h = self.heads - k = self.to_k(context) - v = self.to_v(context) - - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) - - sim = einsum("b i d, b j d -> b i j", q, k) * self.scale - - if exists(mask): - mask = rearrange(mask, "b ... -> b (...)") - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, "b j -> (b h) () j", h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) - - out = einsum("b i j, b j d -> b i d", attn, v) - out = rearrange(out, "(b h) n d -> b n (h d)", h=h) - - return out - - def heterogeous_qkv(self, q, uc_context, context_k, context_v, mask): - h = self.heads - k = self.to_k(torch.cat([uc_context, context_k], dim=0)) - v = self.to_v(torch.cat([uc_context, context_v], dim=0)) - - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) - - sim = einsum("b i d, b j d -> b i j", q, k) * self.scale - - if exists(mask): - mask = rearrange(mask, "b ... -> b (...)") - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, "b j -> (b h) () j", h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) - - out = einsum("b i j, b j d -> b i d", attn, v) - out = rearrange(out, "(b h) n d -> b n (h d)", h=h) - return out - - def get_kv(self, context): - return self.to_k(context), self.to_v(context) diff --git a/models/convnext/discriminator.py b/models/convnext/discriminator.py deleted file mode 100644 index 5798bcf..0000000 --- a/models/convnext/discriminator.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch -from timm.models import ConvNeXt -from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD - -from torch.nn import functional as F - - -class ConvNeXtDiscriminator: - def __init__(self, model: ConvNeXt, input_size: int) -> None: - self.net = model - - self.input_size = input_size - - self.img_mean = torch.tensor(IMAGENET_DEFAULT_MEAN).view(1, -1, 1, 1) - self.img_std = torch.tensor(IMAGENET_DEFAULT_STD).view(1, -1, 1, 1) - - def get_score(self, img): - pred = self.get_all(img) - return torch.softmax(pred, dim=-1)[:, 1] - - def get_all(self, img): - img_mean = self.img_mean.to(device=img.device, dtype=img.dtype) - img_std = self.img_std.to(device=img.device, dtype=img.dtype) - - img = ((img + 1.0) / 2.0).sub(img_mean).div(img_std) - - img = F.interpolate( - img, - size=(self.input_size, self.input_size), - mode="bicubic", - align_corners=True, - ) - pred = self.net(img) - return pred -- cgit v1.2.3-70-g09d2