diff options
| -rw-r--r-- | data/csv.py | 33 | ||||
| -rw-r--r-- | dreambooth.py | 34 | ||||
| -rw-r--r-- | models/attention/control.py | 168 | ||||
| -rw-r--r-- | models/attention/hook.py | 62 | ||||
| -rw-r--r-- | models/attention/structured.py | 132 | ||||
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 3 |
6 files changed, 415 insertions, 17 deletions
diff --git a/data/csv.py b/data/csv.py index df15c5a..5144c0a 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -99,7 +99,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 99 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, | 99 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, |
| 100 | instance_identifier=self.instance_identifier, | 100 | instance_identifier=self.instance_identifier, |
| 101 | size=self.size, interpolation=self.interpolation, | 101 | size=self.size, interpolation=self.interpolation, |
| 102 | center_crop=self.center_crop, repeats=self.repeats) | 102 | center_crop=self.center_crop) |
| 103 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, | 103 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, |
| 104 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn) | 104 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn) |
| 105 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, | 105 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, |
| @@ -157,6 +157,17 @@ class CSVDataset(Dataset): | |||
| 157 | def __len__(self): | 157 | def __len__(self): |
| 158 | return math.ceil(self._length / self.batch_size) * self.batch_size | 158 | return math.ceil(self._length / self.batch_size) * self.batch_size |
| 159 | 159 | ||
| 160 | def get_image(self, path): | ||
| 161 | if path in self.image_cache: | ||
| 162 | return self.image_cache[path] | ||
| 163 | |||
| 164 | image = Image.open(path) | ||
| 165 | if not image.mode == "RGB": | ||
| 166 | image = image.convert("RGB") | ||
| 167 | self.image_cache[path] = image | ||
| 168 | |||
| 169 | return image | ||
| 170 | |||
| 160 | def get_example(self, i): | 171 | def get_example(self, i): |
| 161 | item = self.data[i % self.num_instance_images] | 172 | item = self.data[i % self.num_instance_images] |
| 162 | cache_key = f"{item.instance_image_path}_{item.class_image_path}" | 173 | cache_key = f"{item.instance_image_path}_{item.class_image_path}" |
| @@ -169,30 +180,18 @@ class CSVDataset(Dataset): | |||
| 169 | example["prompts"] = item.prompt | 180 | example["prompts"] = item.prompt |
| 170 | example["nprompts"] = item.nprompt | 181 | example["nprompts"] = item.nprompt |
| 171 | 182 | ||
| 172 | if item.instance_image_path in self.image_cache: | 183 | example["instance_images"] = self.get_image(item.instance_image_path) |
| 173 | instance_image = self.image_cache[item.instance_image_path] | ||
| 174 | else: | ||
| 175 | instance_image = Image.open(item.instance_image_path) | ||
| 176 | if not instance_image.mode == "RGB": | ||
| 177 | instance_image = instance_image.convert("RGB") | ||
| 178 | self.image_cache[item.instance_image_path] = instance_image | ||
| 179 | |||
| 180 | example["instance_images"] = instance_image | ||
| 181 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( | 184 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( |
| 182 | item.prompt.format(self.instance_identifier) | 185 | item.prompt.format(self.instance_identifier) |
| 183 | ) | 186 | ) |
| 184 | 187 | ||
| 185 | if self.num_class_images != 0: | 188 | if self.num_class_images != 0: |
| 186 | class_image = Image.open(item.class_image_path) | 189 | example["class_images"] = self.get_image(item.class_image_path) |
| 187 | if not class_image.mode == "RGB": | ||
| 188 | class_image = class_image.convert("RGB") | ||
| 189 | |||
| 190 | example["class_images"] = class_image | ||
| 191 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids( | 190 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids( |
| 192 | item.nprompt.format(self.class_identifier) | 191 | item.nprompt.format(self.class_identifier) |
| 193 | ) | 192 | ) |
| 194 | 193 | ||
| 195 | self.cache[item.instance_image_path] = example | 194 | self.cache[cache_key] = example |
| 196 | return example | 195 | return example |
| 197 | 196 | ||
| 198 | def __getitem__(self, i): | 197 | def __getitem__(self, i): |
| @@ -204,7 +203,7 @@ class CSVDataset(Dataset): | |||
| 204 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 203 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
| 205 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] | 204 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] |
| 206 | 205 | ||
| 207 | if self.class_identifier is not None: | 206 | if self.num_class_images != 0: |
| 208 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) | 207 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) |
| 209 | example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] | 208 | example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] |
| 210 | 209 | ||
diff --git a/dreambooth.py b/dreambooth.py index 72c56cd..5c26f12 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -354,6 +354,8 @@ class Checkpointer: | |||
| 354 | text_encoder, | 354 | text_encoder, |
| 355 | output_dir: Path, | 355 | output_dir: Path, |
| 356 | instance_identifier, | 356 | instance_identifier, |
| 357 | placeholder_token, | ||
| 358 | placeholder_token_id, | ||
| 357 | sample_image_size, | 359 | sample_image_size, |
| 358 | sample_batches, | 360 | sample_batches, |
| 359 | sample_batch_size, | 361 | sample_batch_size, |
| @@ -368,12 +370,36 @@ class Checkpointer: | |||
| 368 | self.text_encoder = text_encoder | 370 | self.text_encoder = text_encoder |
| 369 | self.output_dir = output_dir | 371 | self.output_dir = output_dir |
| 370 | self.instance_identifier = instance_identifier | 372 | self.instance_identifier = instance_identifier |
| 373 | self.placeholder_token = placeholder_token | ||
| 374 | self.placeholder_token_id = placeholder_token_id | ||
| 371 | self.sample_image_size = sample_image_size | 375 | self.sample_image_size = sample_image_size |
| 372 | self.seed = seed or torch.random.seed() | 376 | self.seed = seed or torch.random.seed() |
| 373 | self.sample_batches = sample_batches | 377 | self.sample_batches = sample_batches |
| 374 | self.sample_batch_size = sample_batch_size | 378 | self.sample_batch_size = sample_batch_size |
| 375 | 379 | ||
| 376 | @torch.no_grad() | 380 | @torch.no_grad() |
| 381 | def save_embedding(self, step, postfix): | ||
| 382 | if self.placeholder_token_id is None: | ||
| 383 | return | ||
| 384 | |||
| 385 | print("Saving checkpoint for step %d..." % step) | ||
| 386 | |||
| 387 | checkpoints_path = self.output_dir.joinpath("checkpoints") | ||
| 388 | checkpoints_path.mkdir(parents=True, exist_ok=True) | ||
| 389 | |||
| 390 | unwrapped = self.accelerator.unwrap_model(self.text_encoder) | ||
| 391 | |||
| 392 | # Save a checkpoint | ||
| 393 | learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] | ||
| 394 | learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} | ||
| 395 | |||
| 396 | filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) | ||
| 397 | torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename)) | ||
| 398 | |||
| 399 | del unwrapped | ||
| 400 | del learned_embeds | ||
| 401 | |||
| 402 | @torch.no_grad() | ||
| 377 | def save_model(self): | 403 | def save_model(self): |
| 378 | print("Saving model...") | 404 | print("Saving model...") |
| 379 | 405 | ||
| @@ -567,6 +593,8 @@ def main(): | |||
| 567 | text_encoder.text_model.final_layer_norm.parameters(), | 593 | text_encoder.text_model.final_layer_norm.parameters(), |
| 568 | text_encoder.text_model.embeddings.position_embedding.parameters(), | 594 | text_encoder.text_model.embeddings.position_embedding.parameters(), |
| 569 | )) | 595 | )) |
| 596 | else: | ||
| 597 | placeholder_token_id = None | ||
| 570 | 598 | ||
| 571 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | 599 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
| 572 | 600 | ||
| @@ -785,6 +813,8 @@ def main(): | |||
| 785 | text_encoder=text_encoder, | 813 | text_encoder=text_encoder, |
| 786 | output_dir=basepath, | 814 | output_dir=basepath, |
| 787 | instance_identifier=instance_identifier, | 815 | instance_identifier=instance_identifier, |
| 816 | placeholder_token=args.placeholder_token, | ||
| 817 | placeholder_token_id=placeholder_token_id, | ||
| 788 | sample_image_size=args.sample_image_size, | 818 | sample_image_size=args.sample_image_size, |
| 789 | sample_batch_size=args.sample_batch_size, | 819 | sample_batch_size=args.sample_batch_size, |
| 790 | sample_batches=args.sample_batches, | 820 | sample_batches=args.sample_batches, |
| @@ -902,6 +932,7 @@ def main(): | |||
| 902 | global_step += 1 | 932 | global_step += 1 |
| 903 | 933 | ||
| 904 | if global_step % args.sample_frequency == 0: | 934 | if global_step % args.sample_frequency == 0: |
| 935 | checkpointer.save_embedding(global_step, "training") | ||
| 905 | sample_checkpoint = True | 936 | sample_checkpoint = True |
| 906 | 937 | ||
| 907 | logs = { | 938 | logs = { |
| @@ -968,6 +999,7 @@ def main(): | |||
| 968 | if min_val_loss > val_loss: | 999 | if min_val_loss > val_loss: |
| 969 | accelerator.print( | 1000 | accelerator.print( |
| 970 | f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") | 1001 | f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") |
| 1002 | checkpointer.save_embedding(global_step, "milestone") | ||
| 971 | min_val_loss = val_loss | 1003 | min_val_loss = val_loss |
| 972 | 1004 | ||
| 973 | if sample_checkpoint and accelerator.is_main_process: | 1005 | if sample_checkpoint and accelerator.is_main_process: |
| @@ -978,6 +1010,7 @@ def main(): | |||
| 978 | # Create the pipeline using using the trained modules and save it. | 1010 | # Create the pipeline using using the trained modules and save it. |
| 979 | if accelerator.is_main_process: | 1011 | if accelerator.is_main_process: |
| 980 | print("Finished! Saving final checkpoint and resume state.") | 1012 | print("Finished! Saving final checkpoint and resume state.") |
| 1013 | checkpointer.save_embedding(global_step, "end") | ||
| 981 | checkpointer.save_model() | 1014 | checkpointer.save_model() |
| 982 | 1015 | ||
| 983 | accelerator.end_training() | 1016 | accelerator.end_training() |
| @@ -985,6 +1018,7 @@ def main(): | |||
| 985 | except KeyboardInterrupt: | 1018 | except KeyboardInterrupt: |
| 986 | if accelerator.is_main_process: | 1019 | if accelerator.is_main_process: |
| 987 | print("Interrupted, saving checkpoint and resume state...") | 1020 | print("Interrupted, saving checkpoint and resume state...") |
| 1021 | checkpointer.save_embedding(global_step, "end") | ||
| 988 | checkpointer.save_model() | 1022 | checkpointer.save_model() |
| 989 | accelerator.end_training() | 1023 | accelerator.end_training() |
| 990 | quit() | 1024 | quit() |
diff --git a/models/attention/control.py b/models/attention/control.py new file mode 100644 index 0000000..248bd9f --- /dev/null +++ b/models/attention/control.py | |||
| @@ -0,0 +1,168 @@ | |||
| 1 | import torch | ||
| 2 | import abc | ||
| 3 | |||
| 4 | |||
| 5 | class AttentionControl(abc.ABC): | ||
| 6 | def step_callback(self, x_t): | ||
| 7 | return x_t | ||
| 8 | |||
| 9 | def between_steps(self): | ||
| 10 | return | ||
| 11 | |||
| 12 | @property | ||
| 13 | def num_uncond_att_layers(self): | ||
| 14 | return self.num_att_layers if LOW_RESOURCE else 0 | ||
| 15 | |||
| 16 | @abc.abstractmethod | ||
| 17 | def forward(self, attn, is_cross: bool, place_in_unet: str): | ||
| 18 | raise NotImplementedError | ||
| 19 | |||
| 20 | def __call__(self, attn, is_cross: bool, place_in_unet: str): | ||
| 21 | if self.cur_att_layer >= self.num_uncond_att_layers: | ||
| 22 | if LOW_RESOURCE: | ||
| 23 | attn = self.forward(attn, is_cross, place_in_unet) | ||
| 24 | else: | ||
| 25 | h = attn.shape[0] | ||
| 26 | attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) | ||
| 27 | self.cur_att_layer += 1 | ||
| 28 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: | ||
| 29 | self.cur_att_layer = 0 | ||
| 30 | self.cur_step += 1 | ||
| 31 | self.between_steps() | ||
| 32 | return attn | ||
| 33 | |||
| 34 | def reset(self): | ||
| 35 | self.cur_step = 0 | ||
| 36 | self.cur_att_layer = 0 | ||
| 37 | |||
| 38 | def __init__(self): | ||
| 39 | self.cur_step = 0 | ||
| 40 | self.num_att_layers = -1 | ||
| 41 | self.cur_att_layer = 0 | ||
| 42 | |||
| 43 | |||
| 44 | class EmptyControl(AttentionControl): | ||
| 45 | def forward(self, attn, is_cross: bool, place_in_unet: str): | ||
| 46 | return attn | ||
| 47 | |||
| 48 | |||
| 49 | class AttentionStore(AttentionControl): | ||
| 50 | @staticmethod | ||
| 51 | def get_empty_store(): | ||
| 52 | return {"down_cross": [], "mid_cross": [], "up_cross": [], | ||
| 53 | "down_self": [], "mid_self": [], "up_self": []} | ||
| 54 | |||
| 55 | def forward(self, attn, is_cross: bool, place_in_unet: str): | ||
| 56 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" | ||
| 57 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead | ||
| 58 | self.step_store[key].append(attn) | ||
| 59 | return attn | ||
| 60 | |||
| 61 | def between_steps(self): | ||
| 62 | if len(self.attention_store) == 0: | ||
| 63 | self.attention_store = self.step_store | ||
| 64 | else: | ||
| 65 | for key in self.attention_store: | ||
| 66 | for i in range(len(self.attention_store[key])): | ||
| 67 | self.attention_store[key][i] += self.step_store[key][i] | ||
| 68 | self.step_store = self.get_empty_store() | ||
| 69 | |||
| 70 | def get_average_attention(self): | ||
| 71 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] | ||
| 72 | for key in self.attention_store} | ||
| 73 | return average_attention | ||
| 74 | |||
| 75 | def reset(self): | ||
| 76 | super(AttentionStore, self).reset() | ||
| 77 | self.step_store = self.get_empty_store() | ||
| 78 | self.attention_store = {} | ||
| 79 | |||
| 80 | def __init__(self): | ||
| 81 | super(AttentionStore, self).__init__() | ||
| 82 | self.step_store = self.get_empty_store() | ||
| 83 | self.attention_store = {} | ||
| 84 | |||
| 85 | |||
| 86 | class AttentionControlEdit(AttentionStore, abc.ABC): | ||
| 87 | def step_callback(self, x_t): | ||
| 88 | if self.local_blend is not None: | ||
| 89 | x_t = self.local_blend(x_t, self.attention_store) | ||
| 90 | return x_t | ||
| 91 | |||
| 92 | def replace_self_attention(self, attn_base, att_replace): | ||
| 93 | if att_replace.shape[2] <= 16 ** 2: | ||
| 94 | return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) | ||
| 95 | else: | ||
| 96 | return att_replace | ||
| 97 | |||
| 98 | @abc.abstractmethod | ||
| 99 | def replace_cross_attention(self, attn_base, att_replace): | ||
| 100 | raise NotImplementedError | ||
| 101 | |||
| 102 | def forward(self, attn, is_cross: bool, place_in_unet: str): | ||
| 103 | super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) | ||
| 104 | if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]): | ||
| 105 | h = attn.shape[0] // (self.batch_size) | ||
| 106 | attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) | ||
| 107 | attn_base, attn_repalce = attn[0], attn[1:] | ||
| 108 | if is_cross: | ||
| 109 | alpha_words = self.cross_replace_alpha[self.cur_step] | ||
| 110 | attn_repalce_new = self.replace_cross_attention( | ||
| 111 | attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce | ||
| 112 | attn[1:] = attn_repalce_new | ||
| 113 | else: | ||
| 114 | attn[1:] = self.replace_self_attention(attn_base, attn_repalce) | ||
| 115 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) | ||
| 116 | return attn | ||
| 117 | |||
| 118 | def __init__(self, prompts, num_steps: int, | ||
| 119 | cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], | ||
| 120 | self_replace_steps: Union[float, Tuple[float, float]], | ||
| 121 | local_blend: Optional[LocalBlend]): | ||
| 122 | super(AttentionControlEdit, self).__init__() | ||
| 123 | self.batch_size = len(prompts) | ||
| 124 | self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha( | ||
| 125 | prompts, num_steps, cross_replace_steps, tokenizer).to(device) | ||
| 126 | if type(self_replace_steps) is float: | ||
| 127 | self_replace_steps = 0, self_replace_steps | ||
| 128 | self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1]) | ||
| 129 | self.local_blend = local_blend | ||
| 130 | |||
| 131 | |||
| 132 | class AttentionReplace(AttentionControlEdit): | ||
| 133 | def replace_cross_attention(self, attn_base, att_replace): | ||
| 134 | return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper) | ||
| 135 | |||
| 136 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, | ||
| 137 | local_blend: Optional[LocalBlend] = None): | ||
| 138 | super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) | ||
| 139 | self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device) | ||
| 140 | |||
| 141 | |||
| 142 | class AttentionRefine(AttentionControlEdit): | ||
| 143 | def replace_cross_attention(self, attn_base, att_replace): | ||
| 144 | attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3) | ||
| 145 | attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) | ||
| 146 | return attn_replace | ||
| 147 | |||
| 148 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, | ||
| 149 | local_blend: Optional[LocalBlend] = None): | ||
| 150 | super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) | ||
| 151 | self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer) | ||
| 152 | self.mapper, alphas = self.mapper.to(device), alphas.to(device) | ||
| 153 | self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) | ||
| 154 | |||
| 155 | |||
| 156 | class AttentionReweight(AttentionControlEdit): | ||
| 157 | def replace_cross_attention(self, attn_base, att_replace): | ||
| 158 | if self.prev_controller is not None: | ||
| 159 | attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace) | ||
| 160 | attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] | ||
| 161 | return attn_replace | ||
| 162 | |||
| 163 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer, | ||
| 164 | local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None): | ||
| 165 | super(AttentionReweight, self).__init__(prompts, num_steps, | ||
| 166 | cross_replace_steps, self_replace_steps, local_blend) | ||
| 167 | self.equalizer = equalizer.to(device) | ||
| 168 | self.prev_controller = controller | ||
diff --git a/models/attention/hook.py b/models/attention/hook.py new file mode 100644 index 0000000..903de02 --- /dev/null +++ b/models/attention/hook.py | |||
| @@ -0,0 +1,62 @@ | |||
| 1 | import torch | ||
| 2 | |||
| 3 | |||
| 4 | try: | ||
| 5 | import xformers.ops | ||
| 6 | xformers._is_functorch_available = True | ||
| 7 | MEM_EFFICIENT_ATTN = True | ||
| 8 | except ImportError: | ||
| 9 | print("[!] Not using xformers memory efficient attention.") | ||
| 10 | MEM_EFFICIENT_ATTN = False | ||
| 11 | |||
| 12 | |||
| 13 | def register_attention_control(model, controller): | ||
| 14 | def ca_forward(self, place_in_unet): | ||
| 15 | def forward(x, context=None, mask=None): | ||
| 16 | batch_size, sequence_length, dim = x.shape | ||
| 17 | h = self.heads | ||
| 18 | q = self.to_q(x) | ||
| 19 | is_cross = context is not None | ||
| 20 | context = context if is_cross else x | ||
| 21 | k = self.to_k(context) | ||
| 22 | v = self.to_v(context) | ||
| 23 | q = self.reshape_heads_to_batch_dim(q) | ||
| 24 | k = self.reshape_heads_to_batch_dim(k) | ||
| 25 | v = self.reshape_heads_to_batch_dim(v) | ||
| 26 | |||
| 27 | sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale | ||
| 28 | |||
| 29 | if mask is not None: | ||
| 30 | mask = mask.reshape(batch_size, -1) | ||
| 31 | max_neg_value = -torch.finfo(sim.dtype).max | ||
| 32 | mask = mask[:, None, :].repeat(h, 1, 1) | ||
| 33 | sim.masked_fill_(~mask, max_neg_value) | ||
| 34 | |||
| 35 | # attention, what we cannot get enough of | ||
| 36 | attn = sim.softmax(dim=-1) | ||
| 37 | attn = controller(attn, is_cross, place_in_unet) | ||
| 38 | out = torch.einsum("b i j, b j d -> b i d", attn, v) | ||
| 39 | out = self.reshape_batch_dim_to_heads(out) | ||
| 40 | return self.to_out(out) | ||
| 41 | |||
| 42 | return forward | ||
| 43 | |||
| 44 | def register_recr(net_, count, place_in_unet): | ||
| 45 | if net_.__class__.__name__ == 'CrossAttention': | ||
| 46 | net_.forward = ca_forward(net_, place_in_unet) | ||
| 47 | return count + 1 | ||
| 48 | elif hasattr(net_, 'children'): | ||
| 49 | for net__ in net_.children(): | ||
| 50 | count = register_recr(net__, count, place_in_unet) | ||
| 51 | return count | ||
| 52 | |||
| 53 | cross_att_count = 0 | ||
| 54 | sub_nets = model.unet.named_children() | ||
| 55 | for net in sub_nets: | ||
| 56 | if "down" in net[0]: | ||
| 57 | cross_att_count += register_recr(net[1], 0, "down") | ||
| 58 | elif "up" in net[0]: | ||
| 59 | cross_att_count += register_recr(net[1], 0, "up") | ||
| 60 | elif "mid" in net[0]: | ||
| 61 | cross_att_count += register_recr(net[1], 0, "mid") | ||
| 62 | controller.num_att_layers = cross_att_count | ||
diff --git a/models/attention/structured.py b/models/attention/structured.py new file mode 100644 index 0000000..24d889f --- /dev/null +++ b/models/attention/structured.py | |||
| @@ -0,0 +1,132 @@ | |||
| 1 | import torch | ||
| 2 | |||
| 3 | from .control import AttentionControl | ||
| 4 | |||
| 5 | |||
| 6 | class StructuredAttentionControl(AttentionControl): | ||
| 7 | def forward(self, attn, is_cross: bool, place_in_unet: str): | ||
| 8 | return attn | ||
| 9 | |||
| 10 | def forward(self, x, context=None, mask=None): | ||
| 11 | h = self.heads | ||
| 12 | |||
| 13 | q = self.to_q(x) | ||
| 14 | |||
| 15 | if isinstance(context, list): | ||
| 16 | if self.struct_attn: | ||
| 17 | out = self.struct_qkv(q, context, mask) | ||
| 18 | else: | ||
| 19 | context = torch.cat([context[0], context[1]['k'][0]], dim=0) # use key tensor for context | ||
| 20 | out = self.normal_qkv(q, context, mask) | ||
| 21 | else: | ||
| 22 | context = default(context, x) | ||
| 23 | out = self.normal_qkv(q, context, mask) | ||
| 24 | |||
| 25 | return self.to_out(out) | ||
| 26 | |||
| 27 | def struct_qkv(self, q, context, mask): | ||
| 28 | """ | ||
| 29 | context: list of [uc, list of conditional context] | ||
| 30 | """ | ||
| 31 | uc_context = context[0] | ||
| 32 | context_k, context_v = context[1]['k'], context[1]['v'] | ||
| 33 | |||
| 34 | if isinstance(context_k, list) and isinstance(context_v, list): | ||
| 35 | out = self.multi_qkv(q, uc_context, context_k, context_v, mask) | ||
| 36 | elif isinstance(context_k, torch.Tensor) and isinstance(context_v, torch.Tensor): | ||
| 37 | out = self.heterogeous_qkv(q, uc_context, context_k, context_v, mask) | ||
| 38 | else: | ||
| 39 | raise NotImplementedError | ||
| 40 | |||
| 41 | return out | ||
| 42 | |||
| 43 | def multi_qkv(self, q, uc_context, context_k, context_v, mask): | ||
| 44 | h = self.heads | ||
| 45 | |||
| 46 | assert uc_context.size(0) == context_k[0].size(0) == context_v[0].size(0) | ||
| 47 | true_bs = uc_context.size(0) * h | ||
| 48 | |||
| 49 | k_uc, v_uc = self.get_kv(uc_context) | ||
| 50 | k_c = [self.to_k(c_k) for c_k in context_k] | ||
| 51 | v_c = [self.to_v(c_v) for c_v in context_v] | ||
| 52 | |||
| 53 | q = rearrange(q, 'b n (h d) -> (b h) n d', h=h) | ||
| 54 | |||
| 55 | k_uc = rearrange(k_uc, 'b n (h d) -> (b h) n d', h=h) | ||
| 56 | v_uc = rearrange(v_uc, 'b n (h d) -> (b h) n d', h=h) | ||
| 57 | |||
| 58 | k_c = [rearrange(k, 'b n (h d) -> (b h) n d', h=h) for k in k_c] # NOTE: modification point | ||
| 59 | v_c = [rearrange(v, 'b n (h d) -> (b h) n d', h=h) for v in v_c] | ||
| 60 | |||
| 61 | # get composition | ||
| 62 | sim_uc = einsum('b i d, b j d -> b i j', q[:true_bs], k_uc) * self.scale | ||
| 63 | sim_c = [einsum('b i d, b j d -> b i j', q[true_bs:], k) * self.scale for k in k_c] | ||
| 64 | |||
| 65 | attn_uc = sim_uc.softmax(dim=-1) | ||
| 66 | attn_c = [sim.softmax(dim=-1) for sim in sim_c] | ||
| 67 | |||
| 68 | # get uc output | ||
| 69 | out_uc = einsum('b i j, b j d -> b i d', attn_uc, v_uc) | ||
| 70 | |||
| 71 | # get c output | ||
| 72 | if len(v_c) == 1: | ||
| 73 | out_c_collect = [] | ||
| 74 | for attn in attn_c: | ||
| 75 | for v in v_c: | ||
| 76 | out_c_collect.append(einsum('b i j, b j d -> b i d', attn, v)) | ||
| 77 | out_c = sum(out_c_collect) / len(out_c_collect) | ||
| 78 | else: | ||
| 79 | out_c = sum([einsum('b i j, b j d -> b i d', attn, v) for attn, v in zip(attn_c, v_c)]) / len(v_c) | ||
| 80 | |||
| 81 | out = torch.cat([out_uc, out_c], dim=0) | ||
| 82 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) | ||
| 83 | |||
| 84 | return out | ||
| 85 | |||
| 86 | def normal_qkv(self, q, context, mask): | ||
| 87 | h = self.heads | ||
| 88 | k = self.to_k(context) | ||
| 89 | v = self.to_v(context) | ||
| 90 | |||
| 91 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) | ||
| 92 | |||
| 93 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale | ||
| 94 | |||
| 95 | if exists(mask): | ||
| 96 | mask = rearrange(mask, 'b ... -> b (...)') | ||
| 97 | max_neg_value = -torch.finfo(sim.dtype).max | ||
| 98 | mask = repeat(mask, 'b j -> (b h) () j', h=h) | ||
| 99 | sim.masked_fill_(~mask, max_neg_value) | ||
| 100 | |||
| 101 | # attention, what we cannot get enough of | ||
| 102 | attn = sim.softmax(dim=-1) | ||
| 103 | |||
| 104 | out = einsum('b i j, b j d -> b i d', attn, v) | ||
| 105 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) | ||
| 106 | |||
| 107 | return out | ||
| 108 | |||
| 109 | def heterogeous_qkv(self, q, uc_context, context_k, context_v, mask): | ||
| 110 | h = self.heads | ||
| 111 | k = self.to_k(torch.cat([uc_context, context_k], dim=0)) | ||
| 112 | v = self.to_v(torch.cat([uc_context, context_v], dim=0)) | ||
| 113 | |||
| 114 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) | ||
| 115 | |||
| 116 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale | ||
| 117 | |||
| 118 | if exists(mask): | ||
| 119 | mask = rearrange(mask, 'b ... -> b (...)') | ||
| 120 | max_neg_value = -torch.finfo(sim.dtype).max | ||
| 121 | mask = repeat(mask, 'b j -> (b h) () j', h=h) | ||
| 122 | sim.masked_fill_(~mask, max_neg_value) | ||
| 123 | |||
| 124 | # attention, what we cannot get enough of | ||
| 125 | attn = sim.softmax(dim=-1) | ||
| 126 | |||
| 127 | out = einsum('b i j, b j d -> b i d', attn, v) | ||
| 128 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) | ||
| 129 | return out | ||
| 130 | |||
| 131 | def get_kv(self, context): | ||
| 132 | return self.to_k(context), self.to_v(context) | ||
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 3da0169..e90528d 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
| @@ -159,6 +159,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 159 | 159 | ||
| 160 | batch_size = len(prompt) | 160 | batch_size = len(prompt) |
| 161 | 161 | ||
| 162 | if negative_prompt is None: | ||
| 163 | negative_prompt = "" | ||
| 164 | |||
| 162 | if isinstance(negative_prompt, str): | 165 | if isinstance(negative_prompt, str): |
| 163 | negative_prompt = [negative_prompt] * batch_size | 166 | negative_prompt = [negative_prompt] * batch_size |
| 164 | 167 | ||
