From 186a69104530610f8c2b924f79a04f941e5238c8 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 22 Jun 2023 07:33:29 +0200
Subject: Remove convnext

---
 models/attention/control.py | 216 --------------------------------------------
 1 file changed, 216 deletions(-)
 delete mode 100644 models/attention/control.py

diff --git a/models/attention/control.py b/models/attention/control.py
deleted file mode 100644
index ec378c4..0000000
--- a/models/attention/control.py
+++ /dev/null
@@ -1,216 +0,0 @@
-import torch
-import abc
-
-
-class AttentionControl(abc.ABC):
-    def step_callback(self, x_t):
-        return x_t
-
-    def between_steps(self):
-        return
-
-    @property
-    def num_uncond_att_layers(self):
-        return self.num_att_layers if LOW_RESOURCE else 0
-
-    @abc.abstractmethod
-    def forward(self, attn, is_cross: bool, place_in_unet: str):
-        raise NotImplementedError
-
-    def __call__(self, attn, is_cross: bool, place_in_unet: str):
-        if self.cur_att_layer >= self.num_uncond_att_layers:
-            if LOW_RESOURCE:
-                attn = self.forward(attn, is_cross, place_in_unet)
-            else:
-                h = attn.shape[0]
-                attn[h // 2 :] = self.forward(attn[h // 2 :], is_cross, place_in_unet)
-        self.cur_att_layer += 1
-        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
-            self.cur_att_layer = 0
-            self.cur_step += 1
-            self.between_steps()
-        return attn
-
-    def reset(self):
-        self.cur_step = 0
-        self.cur_att_layer = 0
-
-    def __init__(self):
-        self.cur_step = 0
-        self.num_att_layers = -1
-        self.cur_att_layer = 0
-
-
-class EmptyControl(AttentionControl):
-    def forward(self, attn, is_cross: bool, place_in_unet: str):
-        return attn
-
-
-class AttentionStore(AttentionControl):
-    @staticmethod
-    def get_empty_store():
-        return {
-            "down_cross": [],
-            "mid_cross": [],
-            "up_cross": [],
-            "down_self": [],
-            "mid_self": [],
-            "up_self": [],
-        }
-
-    def forward(self, attn, is_cross: bool, place_in_unet: str):
-        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
-        if attn.shape[1] <= 32**2:  # avoid memory overhead
-            self.step_store[key].append(attn)
-        return attn
-
-    def between_steps(self):
-        if len(self.attention_store) == 0:
-            self.attention_store = self.step_store
-        else:
-            for key in self.attention_store:
-                for i in range(len(self.attention_store[key])):
-                    self.attention_store[key][i] += self.step_store[key][i]
-        self.step_store = self.get_empty_store()
-
-    def get_average_attention(self):
-        average_attention = {
-            key: [item / self.cur_step for item in self.attention_store[key]]
-            for key in self.attention_store
-        }
-        return average_attention
-
-    def reset(self):
-        super(AttentionStore, self).reset()
-        self.step_store = self.get_empty_store()
-        self.attention_store = {}
-
-    def __init__(self):
-        super(AttentionStore, self).__init__()
-        self.step_store = self.get_empty_store()
-        self.attention_store = {}
-
-
-class AttentionControlEdit(AttentionStore, abc.ABC):
-    def step_callback(self, x_t):
-        if self.local_blend is not None:
-            x_t = self.local_blend(x_t, self.attention_store)
-        return x_t
-
-    def replace_self_attention(self, attn_base, att_replace):
-        if att_replace.shape[2] <= 16**2:
-            return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
-        else:
-            return att_replace
-
-    @abc.abstractmethod
-    def replace_cross_attention(self, attn_base, att_replace):
-        raise NotImplementedError
-
-    def forward(self, attn, is_cross: bool, place_in_unet: str):
-        super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
-        if is_cross or (
-            self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]
-        ):
-            h = attn.shape[0] // (self.batch_size)
-            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
-            attn_base, attn_repalce = attn[0], attn[1:]
-            if is_cross:
-                alpha_words = self.cross_replace_alpha[self.cur_step]
-                attn_repalce_new = (
-                    self.replace_cross_attention(attn_base, attn_repalce) * alpha_words
-                    + (1 - alpha_words) * attn_repalce
-                )
-                attn[1:] = attn_repalce_new
-            else:
-                attn[1:] = self.replace_self_attention(attn_base, attn_repalce)
-            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
-        return attn
-
-    def __init__(
-        self,
-        prompts,
-        num_steps: int,
-        cross_replace_steps: Union[
-            float, Tuple[float, float], Dict[str, Tuple[float, float]]
-        ],
-        self_replace_steps: Union[float, Tuple[float, float]],
-        local_blend: Optional[LocalBlend],
-    ):
-        super(AttentionControlEdit, self).__init__()
-        self.batch_size = len(prompts)
-        self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(
-            prompts, num_steps, cross_replace_steps, tokenizer
-        ).to(device)
-        if type(self_replace_steps) is float:
-            self_replace_steps = 0, self_replace_steps
-        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(
-            num_steps * self_replace_steps[1]
-        )
-        self.local_blend = local_blend
-
-
-class AttentionReplace(AttentionControlEdit):
-    def replace_cross_attention(self, attn_base, att_replace):
-        return torch.einsum("hpw,bwn->bhpn", attn_base, self.mapper)
-
-    def __init__(
-        self,
-        prompts,
-        num_steps: int,
-        cross_replace_steps: float,
-        self_replace_steps: float,
-        local_blend: Optional[LocalBlend] = None,
-    ):
-        super(AttentionReplace, self).__init__(
-            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend
-        )
-        self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device)
-
-
-class AttentionRefine(AttentionControlEdit):
-    def replace_cross_attention(self, attn_base, att_replace):
-        attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
-        attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
-        return attn_replace
-
-    def __init__(
-        self,
-        prompts,
-        num_steps: int,
-        cross_replace_steps: float,
-        self_replace_steps: float,
-        local_blend: Optional[LocalBlend] = None,
-    ):
-        super(AttentionRefine, self).__init__(
-            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend
-        )
-        self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
-        self.mapper, alphas = self.mapper.to(device), alphas.to(device)
-        self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
-
-
-class AttentionReweight(AttentionControlEdit):
-    def replace_cross_attention(self, attn_base, att_replace):
-        if self.prev_controller is not None:
-            attn_base = self.prev_controller.replace_cross_attention(
-                attn_base, att_replace
-            )
-        attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
-        return attn_replace
-
-    def __init__(
-        self,
-        prompts,
-        num_steps: int,
-        cross_replace_steps: float,
-        self_replace_steps: float,
-        equalizer,
-        local_blend: Optional[LocalBlend] = None,
-        controller: Optional[AttentionControlEdit] = None,
-    ):
-        super(AttentionReweight, self).__init__(
-            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend
-        )
-        self.equalizer = equalizer.to(device)
-        self.prev_controller = controller
--
cgit v1.2.3-54-g00ecf