From 847ec3b6c43c89ef3649715f86ecfed370b6e442 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 24 Oct 2022 07:34:30 +0200 Subject: Update --- data/csv.py | 33 ++-- dreambooth.py | 34 +++++ models/attention/control.py | 168 +++++++++++++++++++++ models/attention/hook.py | 62 ++++++++ models/attention/structured.py | 132 ++++++++++++++++ .../stable_diffusion/vlpn_stable_diffusion.py | 3 + 6 files changed, 415 insertions(+), 17 deletions(-) create mode 100644 models/attention/control.py create mode 100644 models/attention/hook.py create mode 100644 models/attention/structured.py diff --git a/data/csv.py b/data/csv.py index df15c5a..5144c0a 100644 --- a/data/csv.py +++ b/data/csv.py @@ -99,7 +99,7 @@ class CSVDataModule(pl.LightningDataModule): val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, instance_identifier=self.instance_identifier, size=self.size, interpolation=self.interpolation, - center_crop=self.center_crop, repeats=self.repeats) + center_crop=self.center_crop) self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, pin_memory=True, collate_fn=self.collate_fn) self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, @@ -157,6 +157,17 @@ class CSVDataset(Dataset): def __len__(self): return math.ceil(self._length / self.batch_size) * self.batch_size + def get_image(self, path): + if path in self.image_cache: + return self.image_cache[path] + + image = Image.open(path) + if not image.mode == "RGB": + image = image.convert("RGB") + self.image_cache[path] = image + + return image + def get_example(self, i): item = self.data[i % self.num_instance_images] cache_key = f"{item.instance_image_path}_{item.class_image_path}" @@ -169,30 +180,18 @@ class CSVDataset(Dataset): example["prompts"] = item.prompt example["nprompts"] = item.nprompt - if item.instance_image_path in self.image_cache: - instance_image = self.image_cache[item.instance_image_path] - else: - instance_image = Image.open(item.instance_image_path) - if not instance_image.mode == "RGB": - instance_image = instance_image.convert("RGB") - self.image_cache[item.instance_image_path] = instance_image - - example["instance_images"] = instance_image + example["instance_images"] = self.get_image(item.instance_image_path) example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( item.prompt.format(self.instance_identifier) ) if self.num_class_images != 0: - class_image = Image.open(item.class_image_path) - if not class_image.mode == "RGB": - class_image = class_image.convert("RGB") - - example["class_images"] = class_image + example["class_images"] = self.get_image(item.class_image_path) example["class_prompt_ids"] = self.prompt_processor.get_input_ids( item.nprompt.format(self.class_identifier) ) - self.cache[item.instance_image_path] = example + self.cache[cache_key] = example return example def __getitem__(self, i): @@ -204,7 +203,7 @@ class CSVDataset(Dataset): example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] - if self.class_identifier is not None: + if self.num_class_images != 0: example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] diff --git a/dreambooth.py b/dreambooth.py index 72c56cd..5c26f12 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -354,6 +354,8 @@ class Checkpointer: text_encoder, output_dir: Path, instance_identifier, + placeholder_token, + placeholder_token_id, sample_image_size, sample_batches, sample_batch_size, @@ -368,11 +370,35 @@ class Checkpointer: self.text_encoder = text_encoder self.output_dir = output_dir self.instance_identifier = instance_identifier + self.placeholder_token = placeholder_token + self.placeholder_token_id = placeholder_token_id self.sample_image_size = sample_image_size self.seed = seed or torch.random.seed() self.sample_batches = sample_batches self.sample_batch_size = sample_batch_size + @torch.no_grad() + def save_embedding(self, step, postfix): + if self.placeholder_token_id is None: + return + + print("Saving checkpoint for step %d..." % step) + + checkpoints_path = self.output_dir.joinpath("checkpoints") + checkpoints_path.mkdir(parents=True, exist_ok=True) + + unwrapped = self.accelerator.unwrap_model(self.text_encoder) + + # Save a checkpoint + learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] + learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} + + filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) + torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename)) + + del unwrapped + del learned_embeds + @torch.no_grad() def save_model(self): print("Saving model...") @@ -567,6 +593,8 @@ def main(): text_encoder.text_model.final_layer_norm.parameters(), text_encoder.text_model.embeddings.position_embedding.parameters(), )) + else: + placeholder_token_id = None prompt_processor = PromptProcessor(tokenizer, text_encoder) @@ -785,6 +813,8 @@ def main(): text_encoder=text_encoder, output_dir=basepath, instance_identifier=instance_identifier, + placeholder_token=args.placeholder_token, + placeholder_token_id=placeholder_token_id, sample_image_size=args.sample_image_size, sample_batch_size=args.sample_batch_size, sample_batches=args.sample_batches, @@ -902,6 +932,7 @@ def main(): global_step += 1 if global_step % args.sample_frequency == 0: + checkpointer.save_embedding(global_step, "training") sample_checkpoint = True logs = { @@ -968,6 +999,7 @@ def main(): if min_val_loss > val_loss: accelerator.print( f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") + checkpointer.save_embedding(global_step, "milestone") min_val_loss = val_loss if sample_checkpoint and accelerator.is_main_process: @@ -978,6 +1010,7 @@ def main(): # Create the pipeline using using the trained modules and save it. if accelerator.is_main_process: print("Finished! Saving final checkpoint and resume state.") + checkpointer.save_embedding(global_step, "end") checkpointer.save_model() accelerator.end_training() @@ -985,6 +1018,7 @@ def main(): except KeyboardInterrupt: if accelerator.is_main_process: print("Interrupted, saving checkpoint and resume state...") + checkpointer.save_embedding(global_step, "end") checkpointer.save_model() accelerator.end_training() quit() diff --git a/models/attention/control.py b/models/attention/control.py new file mode 100644 index 0000000..248bd9f --- /dev/null +++ b/models/attention/control.py @@ -0,0 +1,168 @@ +import torch +import abc + + +class AttentionControl(abc.ABC): + def step_callback(self, x_t): + return x_t + + def between_steps(self): + return + + @property + def num_uncond_att_layers(self): + return self.num_att_layers if LOW_RESOURCE else 0 + + @abc.abstractmethod + def forward(self, attn, is_cross: bool, place_in_unet: str): + raise NotImplementedError + + def __call__(self, attn, is_cross: bool, place_in_unet: str): + if self.cur_att_layer >= self.num_uncond_att_layers: + if LOW_RESOURCE: + attn = self.forward(attn, is_cross, place_in_unet) + else: + h = attn.shape[0] + attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) + self.cur_att_layer += 1 + if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: + self.cur_att_layer = 0 + self.cur_step += 1 + self.between_steps() + return attn + + def reset(self): + self.cur_step = 0 + self.cur_att_layer = 0 + + def __init__(self): + self.cur_step = 0 + self.num_att_layers = -1 + self.cur_att_layer = 0 + + +class EmptyControl(AttentionControl): + def forward(self, attn, is_cross: bool, place_in_unet: str): + return attn + + +class AttentionStore(AttentionControl): + @staticmethod + def get_empty_store(): + return {"down_cross": [], "mid_cross": [], "up_cross": [], + "down_self": [], "mid_self": [], "up_self": []} + + def forward(self, attn, is_cross: bool, place_in_unet: str): + key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" + if attn.shape[1] <= 32 ** 2: # avoid memory overhead + self.step_store[key].append(attn) + return attn + + def between_steps(self): + if len(self.attention_store) == 0: + self.attention_store = self.step_store + else: + for key in self.attention_store: + for i in range(len(self.attention_store[key])): + self.attention_store[key][i] += self.step_store[key][i] + self.step_store = self.get_empty_store() + + def get_average_attention(self): + average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] + for key in self.attention_store} + return average_attention + + def reset(self): + super(AttentionStore, self).reset() + self.step_store = self.get_empty_store() + self.attention_store = {} + + def __init__(self): + super(AttentionStore, self).__init__() + self.step_store = self.get_empty_store() + self.attention_store = {} + + +class AttentionControlEdit(AttentionStore, abc.ABC): + def step_callback(self, x_t): + if self.local_blend is not None: + x_t = self.local_blend(x_t, self.attention_store) + return x_t + + def replace_self_attention(self, attn_base, att_replace): + if att_replace.shape[2] <= 16 ** 2: + return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) + else: + return att_replace + + @abc.abstractmethod + def replace_cross_attention(self, attn_base, att_replace): + raise NotImplementedError + + def forward(self, attn, is_cross: bool, place_in_unet: str): + super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) + if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]): + h = attn.shape[0] // (self.batch_size) + attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) + attn_base, attn_repalce = attn[0], attn[1:] + if is_cross: + alpha_words = self.cross_replace_alpha[self.cur_step] + attn_repalce_new = self.replace_cross_attention( + attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce + attn[1:] = attn_repalce_new + else: + attn[1:] = self.replace_self_attention(attn_base, attn_repalce) + attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) + return attn + + def __init__(self, prompts, num_steps: int, + cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], + self_replace_steps: Union[float, Tuple[float, float]], + local_blend: Optional[LocalBlend]): + super(AttentionControlEdit, self).__init__() + self.batch_size = len(prompts) + self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha( + prompts, num_steps, cross_replace_steps, tokenizer).to(device) + if type(self_replace_steps) is float: + self_replace_steps = 0, self_replace_steps + self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1]) + self.local_blend = local_blend + + +class AttentionReplace(AttentionControlEdit): + def replace_cross_attention(self, attn_base, att_replace): + return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper) + + def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, + local_blend: Optional[LocalBlend] = None): + super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) + self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device) + + +class AttentionRefine(AttentionControlEdit): + def replace_cross_attention(self, attn_base, att_replace): + attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3) + attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) + return attn_replace + + def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, + local_blend: Optional[LocalBlend] = None): + super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) + self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer) + self.mapper, alphas = self.mapper.to(device), alphas.to(device) + self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) + + +class AttentionReweight(AttentionControlEdit): + def replace_cross_attention(self, attn_base, att_replace): + if self.prev_controller is not None: + attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace) + attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] + return attn_replace + + def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer, + local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None): + super(AttentionReweight, self).__init__(prompts, num_steps, + cross_replace_steps, self_replace_steps, local_blend) + self.equalizer = equalizer.to(device) + self.prev_controller = controller diff --git a/models/attention/hook.py b/models/attention/hook.py new file mode 100644 index 0000000..903de02 --- /dev/null +++ b/models/attention/hook.py @@ -0,0 +1,62 @@ +import torch + + +try: + import xformers.ops + xformers._is_functorch_available = True + MEM_EFFICIENT_ATTN = True +except ImportError: + print("[!] Not using xformers memory efficient attention.") + MEM_EFFICIENT_ATTN = False + + +def register_attention_control(model, controller): + def ca_forward(self, place_in_unet): + def forward(x, context=None, mask=None): + batch_size, sequence_length, dim = x.shape + h = self.heads + q = self.to_q(x) + is_cross = context is not None + context = context if is_cross else x + k = self.to_k(context) + v = self.to_v(context) + q = self.reshape_heads_to_batch_dim(q) + k = self.reshape_heads_to_batch_dim(k) + v = self.reshape_heads_to_batch_dim(v) + + sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale + + if mask is not None: + mask = mask.reshape(batch_size, -1) + max_neg_value = -torch.finfo(sim.dtype).max + mask = mask[:, None, :].repeat(h, 1, 1) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + attn = controller(attn, is_cross, place_in_unet) + out = torch.einsum("b i j, b j d -> b i d", attn, v) + out = self.reshape_batch_dim_to_heads(out) + return self.to_out(out) + + return forward + + def register_recr(net_, count, place_in_unet): + if net_.__class__.__name__ == 'CrossAttention': + net_.forward = ca_forward(net_, place_in_unet) + return count + 1 + elif hasattr(net_, 'children'): + for net__ in net_.children(): + count = register_recr(net__, count, place_in_unet) + return count + + cross_att_count = 0 + sub_nets = model.unet.named_children() + for net in sub_nets: + if "down" in net[0]: + cross_att_count += register_recr(net[1], 0, "down") + elif "up" in net[0]: + cross_att_count += register_recr(net[1], 0, "up") + elif "mid" in net[0]: + cross_att_count += register_recr(net[1], 0, "mid") + controller.num_att_layers = cross_att_count diff --git a/models/attention/structured.py b/models/attention/structured.py new file mode 100644 index 0000000..24d889f --- /dev/null +++ b/models/attention/structured.py @@ -0,0 +1,132 @@ +import torch + +from .control import AttentionControl + + +class StructuredAttentionControl(AttentionControl): + def forward(self, attn, is_cross: bool, place_in_unet: str): + return attn + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + + if isinstance(context, list): + if self.struct_attn: + out = self.struct_qkv(q, context, mask) + else: + context = torch.cat([context[0], context[1]['k'][0]], dim=0) # use key tensor for context + out = self.normal_qkv(q, context, mask) + else: + context = default(context, x) + out = self.normal_qkv(q, context, mask) + + return self.to_out(out) + + def struct_qkv(self, q, context, mask): + """ + context: list of [uc, list of conditional context] + """ + uc_context = context[0] + context_k, context_v = context[1]['k'], context[1]['v'] + + if isinstance(context_k, list) and isinstance(context_v, list): + out = self.multi_qkv(q, uc_context, context_k, context_v, mask) + elif isinstance(context_k, torch.Tensor) and isinstance(context_v, torch.Tensor): + out = self.heterogeous_qkv(q, uc_context, context_k, context_v, mask) + else: + raise NotImplementedError + + return out + + def multi_qkv(self, q, uc_context, context_k, context_v, mask): + h = self.heads + + assert uc_context.size(0) == context_k[0].size(0) == context_v[0].size(0) + true_bs = uc_context.size(0) * h + + k_uc, v_uc = self.get_kv(uc_context) + k_c = [self.to_k(c_k) for c_k in context_k] + v_c = [self.to_v(c_v) for c_v in context_v] + + q = rearrange(q, 'b n (h d) -> (b h) n d', h=h) + + k_uc = rearrange(k_uc, 'b n (h d) -> (b h) n d', h=h) + v_uc = rearrange(v_uc, 'b n (h d) -> (b h) n d', h=h) + + k_c = [rearrange(k, 'b n (h d) -> (b h) n d', h=h) for k in k_c] # NOTE: modification point + v_c = [rearrange(v, 'b n (h d) -> (b h) n d', h=h) for v in v_c] + + # get composition + sim_uc = einsum('b i d, b j d -> b i j', q[:true_bs], k_uc) * self.scale + sim_c = [einsum('b i d, b j d -> b i j', q[true_bs:], k) * self.scale for k in k_c] + + attn_uc = sim_uc.softmax(dim=-1) + attn_c = [sim.softmax(dim=-1) for sim in sim_c] + + # get uc output + out_uc = einsum('b i j, b j d -> b i d', attn_uc, v_uc) + + # get c output + if len(v_c) == 1: + out_c_collect = [] + for attn in attn_c: + for v in v_c: + out_c_collect.append(einsum('b i j, b j d -> b i d', attn, v)) + out_c = sum(out_c_collect) / len(out_c_collect) + else: + out_c = sum([einsum('b i j, b j d -> b i d', attn, v) for attn, v in zip(attn_c, v_c)]) / len(v_c) + + out = torch.cat([out_uc, out_c], dim=0) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + + return out + + def normal_qkv(self, q, context, mask): + h = self.heads + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + + return out + + def heterogeous_qkv(self, q, uc_context, context_k, context_v, mask): + h = self.heads + k = self.to_k(torch.cat([uc_context, context_k], dim=0)) + v = self.to_v(torch.cat([uc_context, context_v], dim=0)) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return out + + def get_kv(self, context): + return self.to_k(context), self.to_v(context) diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 3da0169..e90528d 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -159,6 +159,9 @@ class VlpnStableDiffusion(DiffusionPipeline): batch_size = len(prompt) + if negative_prompt is None: + negative_prompt = "" + if isinstance(negative_prompt, str): negative_prompt = [negative_prompt] * batch_size -- cgit v1.2.3-54-g00ecf