From 186a69104530610f8c2b924f79a04f941e5238c8 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Thu, 22 Jun 2023 07:33:29 +0200
Subject: Remove convnext

---
 models/attention/control.py                        | 216 ---------------------
 models/attention/hook.py                           |  63 ------
 models/attention/structured.py                     | 145 --------------
 models/convnext/discriminator.py                   |  34 ----
 .../stable_diffusion/vlpn_stable_diffusion.py      |   1 -
 train_ti.py                                        |  10 -
 training/functional.py                             |  13 --
 7 files changed, 482 deletions(-)
 delete mode 100644 models/attention/control.py
 delete mode 100644 models/attention/hook.py
 delete mode 100644 models/attention/structured.py
 delete mode 100644 models/convnext/discriminator.py

diff --git a/models/attention/control.py b/models/attention/control.py
deleted file mode 100644
index ec378c4..0000000
--- a/models/attention/control.py
+++ /dev/null
@@ -1,216 +0,0 @@
-import torch
-import abc
-
-
-class AttentionControl(abc.ABC):
-    def step_callback(self, x_t):
-        return x_t
-
-    def between_steps(self):
-        return
-
-    @property
-    def num_uncond_att_layers(self):
-        return self.num_att_layers if LOW_RESOURCE else 0
-
-    @abc.abstractmethod
-    def forward(self, attn, is_cross: bool, place_in_unet: str):
-        raise NotImplementedError
-
-    def __call__(self, attn, is_cross: bool, place_in_unet: str):
-        if self.cur_att_layer >= self.num_uncond_att_layers:
-            if LOW_RESOURCE:
-                attn = self.forward(attn, is_cross, place_in_unet)
-            else:
-                h = attn.shape[0]
-                attn[h // 2 :] = self.forward(attn[h // 2 :], is_cross, place_in_unet)
-        self.cur_att_layer += 1
-        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
-            self.cur_att_layer = 0
-            self.cur_step += 1
-            self.between_steps()
-        return attn
-
-    def reset(self):
-        self.cur_step = 0
-        self.cur_att_layer = 0
-
-    def __init__(self):
-        self.cur_step = 0
-        self.num_att_layers = -1
-        self.cur_att_layer = 0
-
-
-class EmptyControl(AttentionControl):
-    def forward(self, attn, is_cross: bool, place_in_unet: str):
-        return attn
-
-
-class AttentionStore(AttentionControl):
-    @staticmethod
-    def get_empty_store():
-        return {
-            "down_cross": [],
-            "mid_cross": [],
-            "up_cross": [],
-            "down_self": [],
-            "mid_self": [],
-            "up_self": [],
-        }
-
-    def forward(self, attn, is_cross: bool, place_in_unet: str):
-        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
-        if attn.shape[1] <= 32**2:  # avoid memory overhead
-            self.step_store[key].append(attn)
-        return attn
-
-    def between_steps(self):
-        if len(self.attention_store) == 0:
-            self.attention_store = self.step_store
-        else:
-            for key in self.attention_store:
-                for i in range(len(self.attention_store[key])):
-                    self.attention_store[key][i] += self.step_store[key][i]
-        self.step_store = self.get_empty_store()
-
-    def get_average_attention(self):
-        average_attention = {
-            key: [item / self.cur_step for item in self.attention_store[key]]
-            for key in self.attention_store
-        }
-        return average_attention
-
-    def reset(self):
-        super(AttentionStore, self).reset()
-        self.step_store = self.get_empty_store()
-        self.attention_store = {}
-
-    def __init__(self):
-        super(AttentionStore, self).__init__()
-        self.step_store = self.get_empty_store()
-        self.attention_store = {}
-
-
-class AttentionControlEdit(AttentionStore, abc.ABC):
-    def step_callback(self, x_t):
-        if self.local_blend is not None:
-            x_t = self.local_blend(x_t, self.attention_store)
-        return x_t
-
-    def replace_self_attention(self, attn_base, att_replace):
-        if att_replace.shape[2] <= 16**2:
-            return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
-        else:
-            return att_replace
-
-    @abc.abstractmethod
-    def replace_cross_attention(self, attn_base, att_replace):
-        raise NotImplementedError
-
-    def forward(self, attn, is_cross: bool, place_in_unet: str):
-        super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
-        if is_cross or (
-            self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]
-        ):
-            h = attn.shape[0] // (self.batch_size)
-            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
-            attn_base, attn_repalce = attn[0], attn[1:]
-            if is_cross:
-                alpha_words = self.cross_replace_alpha[self.cur_step]
-                attn_repalce_new = (
-                    self.replace_cross_attention(attn_base, attn_repalce) * alpha_words
-                    + (1 - alpha_words) * attn_repalce
-                )
-                attn[1:] = attn_repalce_new
-            else:
-                attn[1:] = self.replace_self_attention(attn_base, attn_repalce)
-            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
-        return attn
-
-    def __init__(
-        self,
-        prompts,
-        num_steps: int,
-        cross_replace_steps: Union[
-            float, Tuple[float, float], Dict[str, Tuple[float, float]]
-        ],
-        self_replace_steps: Union[float, Tuple[float, float]],
-        local_blend: Optional[LocalBlend],
-    ):
-        super(AttentionControlEdit, self).__init__()
-        self.batch_size = len(prompts)
-        self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(
-            prompts, num_steps, cross_replace_steps, tokenizer
-        ).to(device)
-        if type(self_replace_steps) is float:
-            self_replace_steps = 0, self_replace_steps
-        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(
-            num_steps * self_replace_steps[1]
-        )
-        self.local_blend = local_blend
-
-
-class AttentionReplace(AttentionControlEdit):
-    def replace_cross_attention(self, attn_base, att_replace):
-        return torch.einsum("hpw,bwn->bhpn", attn_base, self.mapper)
-
-    def __init__(
-        self,
-        prompts,
-        num_steps: int,
-        cross_replace_steps: float,
-        self_replace_steps: float,
-        local_blend: Optional[LocalBlend] = None,
-    ):
-        super(AttentionReplace, self).__init__(
-            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend
-        )
-        self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device)
-
-
-class AttentionRefine(AttentionControlEdit):
-    def replace_cross_attention(self, attn_base, att_replace):
-        attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
-        attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
-        return attn_replace
-
-    def __init__(
-        self,
-        prompts,
-        num_steps: int,
-        cross_replace_steps: float,
-        self_replace_steps: float,
-        local_blend: Optional[LocalBlend] = None,
-    ):
-        super(AttentionRefine, self).__init__(
-            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend
-        )
-        self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
-        self.mapper, alphas = self.mapper.to(device), alphas.to(device)
-        self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
-
-
-class AttentionReweight(AttentionControlEdit):
-    def replace_cross_attention(self, attn_base, att_replace):
-        if self.prev_controller is not None:
-            attn_base = self.prev_controller.replace_cross_attention(
-                attn_base, att_replace
-            )
-        attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
-        return attn_replace
-
-    def __init__(
-        self,
-        prompts,
-        num_steps: int,
-        cross_replace_steps: float,
-        self_replace_steps: float,
-        equalizer,
-        local_blend: Optional[LocalBlend] = None,
-        controller: Optional[AttentionControlEdit] = None,
-    ):
-        super(AttentionReweight, self).__init__(
-            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend
-        )
-        self.equalizer = equalizer.to(device)
-        self.prev_controller = controller
diff --git a/models/attention/hook.py b/models/attention/hook.py
deleted file mode 100644
index 6b5fb68..0000000
--- a/models/attention/hook.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import torch
-
-
-try:
-    import xformers.ops
-
-    xformers._is_functorch_available = True
-    MEM_EFFICIENT_ATTN = True
-except ImportError:
-    print("[!] Not using xformers memory efficient attention.")
-    MEM_EFFICIENT_ATTN = False
-
-
-def register_attention_control(model, controller):
-    def ca_forward(self, place_in_unet):
-        def forward(x, context=None, mask=None):
-            batch_size, sequence_length, dim = x.shape
-            h = self.heads
-            q = self.to_q(x)
-            is_cross = context is not None
-            context = context if is_cross else x
-            k = self.to_k(context)
-            v = self.to_v(context)
-            q = self.reshape_heads_to_batch_dim(q)
-            k = self.reshape_heads_to_batch_dim(k)
-            v = self.reshape_heads_to_batch_dim(v)
-
-            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
-
-            if mask is not None:
-                mask = mask.reshape(batch_size, -1)
-                max_neg_value = -torch.finfo(sim.dtype).max
-                mask = mask[:, None, :].repeat(h, 1, 1)
-                sim.masked_fill_(~mask, max_neg_value)
-
-            # attention, what we cannot get enough of
-            attn = sim.softmax(dim=-1)
-            attn = controller(attn, is_cross, place_in_unet)
-            out = torch.einsum("b i j, b j d -> b i d", attn, v)
-            out = self.reshape_batch_dim_to_heads(out)
-            return self.to_out(out)
-
-        return forward
-
-    def register_recr(net_, count, place_in_unet):
-        if net_.__class__.__name__ == "CrossAttention":
-            net_.forward = ca_forward(net_, place_in_unet)
-            return count + 1
-        elif hasattr(net_, "children"):
-            for net__ in net_.children():
-                count = register_recr(net__, count, place_in_unet)
-        return count
-
-    cross_att_count = 0
-    sub_nets = model.unet.named_children()
-    for net in sub_nets:
-        if "down" in net[0]:
-            cross_att_count += register_recr(net[1], 0, "down")
-        elif "up" in net[0]:
-            cross_att_count += register_recr(net[1], 0, "up")
-        elif "mid" in net[0]:
-            cross_att_count += register_recr(net[1], 0, "mid")
-    controller.num_att_layers = cross_att_count
diff --git a/models/attention/structured.py b/models/attention/structured.py
deleted file mode 100644
index 5bbbc06..0000000
--- a/models/attention/structured.py
+++ /dev/null
@@ -1,145 +0,0 @@
-import torch
-
-from .control import AttentionControl
-
-
-class StructuredAttentionControl(AttentionControl):
-    def forward(self, attn, is_cross: bool, place_in_unet: str):
-        return attn
-
-    def forward(self, x, context=None, mask=None):
-        h = self.heads
-
-        q = self.to_q(x)
-
-        if isinstance(context, list):
-            if self.struct_attn:
-                out = self.struct_qkv(q, context, mask)
-            else:
-                context = torch.cat(
-                    [context[0], context[1]["k"][0]], dim=0
-                )  # use key tensor for context
-                out = self.normal_qkv(q, context, mask)
-        else:
-            context = default(context, x)
-            out = self.normal_qkv(q, context, mask)
-
-        return self.to_out(out)
-
-    def struct_qkv(self, q, context, mask):
-        """
-        context: list of [uc, list of conditional context]
-        """
-        uc_context = context[0]
-        context_k, context_v = context[1]["k"], context[1]["v"]
-
-        if isinstance(context_k, list) and isinstance(context_v, list):
-            out = self.multi_qkv(q, uc_context, context_k, context_v, mask)
-        elif isinstance(context_k, torch.Tensor) and isinstance(
-            context_v, torch.Tensor
-        ):
-            out = self.heterogeous_qkv(q, uc_context, context_k, context_v, mask)
-        else:
-            raise NotImplementedError
-
-        return out
-
-    def multi_qkv(self, q, uc_context, context_k, context_v, mask):
-        h = self.heads
-
-        assert uc_context.size(0) == context_k[0].size(0) == context_v[0].size(0)
-        true_bs = uc_context.size(0) * h
-
-        k_uc, v_uc = self.get_kv(uc_context)
-        k_c = [self.to_k(c_k) for c_k in context_k]
-        v_c = [self.to_v(c_v) for c_v in context_v]
-
-        q = rearrange(q, "b n (h d) -> (b h) n d", h=h)
-
-        k_uc = rearrange(k_uc, "b n (h d) -> (b h) n d", h=h)
-        v_uc = rearrange(v_uc, "b n (h d) -> (b h) n d", h=h)
-
-        k_c = [
-            rearrange(k, "b n (h d) -> (b h) n d", h=h) for k in k_c
-        ]  # NOTE: modification point
-        v_c = [rearrange(v, "b n (h d) -> (b h) n d", h=h) for v in v_c]
-
-        # get composition
-        sim_uc = einsum("b i d, b j d -> b i j", q[:true_bs], k_uc) * self.scale
-        sim_c = [
-            einsum("b i d, b j d -> b i j", q[true_bs:], k) * self.scale for k in k_c
-        ]
-
-        attn_uc = sim_uc.softmax(dim=-1)
-        attn_c = [sim.softmax(dim=-1) for sim in sim_c]
-
-        # get uc output
-        out_uc = einsum("b i j, b j d -> b i d", attn_uc, v_uc)
-
-        # get c output
-        if len(v_c) == 1:
-            out_c_collect = []
-            for attn in attn_c:
-                for v in v_c:
-                    out_c_collect.append(einsum("b i j, b j d -> b i d", attn, v))
-            out_c = sum(out_c_collect) / len(out_c_collect)
-        else:
-            out_c = sum(
-                [
-                    einsum("b i j, b j d -> b i d", attn, v)
-                    for attn, v in zip(attn_c, v_c)
-                ]
-            ) / len(v_c)
-
-        out = torch.cat([out_uc, out_c], dim=0)
-        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
-
-        return out
-
-    def normal_qkv(self, q, context, mask):
-        h = self.heads
-        k = self.to_k(context)
-        v = self.to_v(context)
-
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
-
-        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
-
-        if exists(mask):
-            mask = rearrange(mask, "b ... -> b (...)")
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, "b j -> (b h) () j", h=h)
-            sim.masked_fill_(~mask, max_neg_value)
-
-        # attention, what we cannot get enough of
-        attn = sim.softmax(dim=-1)
-
-        out = einsum("b i j, b j d -> b i d", attn, v)
-        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
-
-        return out
-
-    def heterogeous_qkv(self, q, uc_context, context_k, context_v, mask):
-        h = self.heads
-        k = self.to_k(torch.cat([uc_context, context_k], dim=0))
-        v = self.to_v(torch.cat([uc_context, context_v], dim=0))
-
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
-
-        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
-
-        if exists(mask):
-            mask = rearrange(mask, "b ... -> b (...)")
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, "b j -> (b h) () j", h=h)
-            sim.masked_fill_(~mask, max_neg_value)
-
-        # attention, what we cannot get enough of
-        attn = sim.softmax(dim=-1)
-
-        out = einsum("b i j, b j d -> b i d", attn, v)
-        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
-        return out
-
-    def get_kv(self, context):
-        return self.to_k(context), self.to_v(context)
diff --git a/models/convnext/discriminator.py b/models/convnext/discriminator.py
deleted file mode 100644
index 5798bcf..0000000
--- a/models/convnext/discriminator.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import torch
-from timm.models import ConvNeXt
-from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-
-from torch.nn import functional as F
-
-
-class ConvNeXtDiscriminator:
-    def __init__(self, model: ConvNeXt, input_size: int) -> None:
-        self.net = model
-
-        self.input_size = input_size
-
-        self.img_mean = torch.tensor(IMAGENET_DEFAULT_MEAN).view(1, -1, 1, 1)
-        self.img_std = torch.tensor(IMAGENET_DEFAULT_STD).view(1, -1, 1, 1)
-
-    def get_score(self, img):
-        pred = self.get_all(img)
-        return torch.softmax(pred, dim=-1)[:, 1]
-
-    def get_all(self, img):
-        img_mean = self.img_mean.to(device=img.device, dtype=img.dtype)
-        img_std = self.img_std.to(device=img.device, dtype=img.dtype)
-
-        img = ((img + 1.0) / 2.0).sub(img_mean).div(img_std)
-
-        img = F.interpolate(
-            img,
-            size=(self.input_size, self.input_size),
-            mode="bicubic",
-            align_corners=True,
-        )
-        pred = self.net(img)
-        return pred
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 16b8456..98703d5 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -28,7 +28,6 @@ from diffusers.utils import logging, randn_tensor
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from models.clip.util import unify_input_ids, get_extended_embeddings
-from util.noise import perlin_noise
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
diff --git a/train_ti.py b/train_ti.py
index da0c03e..7d1ef19 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -698,16 +698,6 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
-    # convnext = create_model(
-    #     "convnext_tiny",
-    #     pretrained=False,
-    #     num_classes=3,
-    #     drop_path_rate=0.0,
-    # )
-    # convnext.to(accelerator.device, dtype=weight_dtype)
-    # convnext.requires_grad_(False)
-    # convnext.eval()
-
     if len(args.alias_tokens) != 0:
         alias_placeholder_tokens = args.alias_tokens[::2]
         alias_initializer_tokens = args.alias_tokens[1::2]
diff --git a/training/functional.py b/training/functional.py
index 3c7848f..a3d1f08 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -29,11 +29,8 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
-from models.convnext.discriminator import ConvNeXtDiscriminator
 from training.util import AverageMeter
 from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler
-from util.slerp import slerp
-from util.noise import perlin_noise
 
 
 def const(result=None):
@@ -349,7 +346,6 @@ def loss_step(
     prior_loss_weight: float,
     seed: int,
     input_pertubation: float,
-    disc: Optional[ConvNeXtDiscriminator],
     min_snr_gamma: int,
     step: int,
     batch: dict[str, Any],
@@ -449,13 +445,6 @@ def loss_step(
 
     loss = loss.mean([1, 2, 3])
 
-    if disc is not None:
-        rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps)
-        rec_latent = rec_latent / vae.config.scaling_factor
-        rec_latent = rec_latent.to(dtype=vae.dtype)
-        rec = vae.decode(rec_latent, return_dict=False)[0]
-        loss = 1 - disc.get_score(rec)
-
     if min_snr_gamma != 0:
         snr = compute_snr(timesteps, noise_scheduler)
         mse_loss_weights = (
@@ -741,7 +730,6 @@ def train(
     guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
     input_pertubation: float = 0.1,
-    disc: Optional[ConvNeXtDiscriminator] = None,
     schedule_sampler: Optional[ScheduleSampler] = None,
     min_snr_gamma: int = 5,
     avg_loss: AverageMeter = AverageMeter(),
@@ -803,7 +791,6 @@ def train(
         prior_loss_weight,
         seed,
         input_pertubation,
-        disc,
         min_snr_gamma,
     )
 
-- 
cgit v1.2.3-70-g09d2