summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py33
-rw-r--r--dreambooth.py34
-rw-r--r--models/attention/control.py168
-rw-r--r--models/attention/hook.py62
-rw-r--r--models/attention/structured.py132
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py3
6 files changed, 415 insertions, 17 deletions
diff --git a/data/csv.py b/data/csv.py
index df15c5a..5144c0a 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -99,7 +99,7 @@ class CSVDataModule(pl.LightningDataModule):
99 val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, 99 val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size,
100 instance_identifier=self.instance_identifier, 100 instance_identifier=self.instance_identifier,
101 size=self.size, interpolation=self.interpolation, 101 size=self.size, interpolation=self.interpolation,
102 center_crop=self.center_crop, repeats=self.repeats) 102 center_crop=self.center_crop)
103 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, 103 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
104 shuffle=True, pin_memory=True, collate_fn=self.collate_fn) 104 shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
105 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, 105 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size,
@@ -157,6 +157,17 @@ class CSVDataset(Dataset):
157 def __len__(self): 157 def __len__(self):
158 return math.ceil(self._length / self.batch_size) * self.batch_size 158 return math.ceil(self._length / self.batch_size) * self.batch_size
159 159
160 def get_image(self, path):
161 if path in self.image_cache:
162 return self.image_cache[path]
163
164 image = Image.open(path)
165 if not image.mode == "RGB":
166 image = image.convert("RGB")
167 self.image_cache[path] = image
168
169 return image
170
160 def get_example(self, i): 171 def get_example(self, i):
161 item = self.data[i % self.num_instance_images] 172 item = self.data[i % self.num_instance_images]
162 cache_key = f"{item.instance_image_path}_{item.class_image_path}" 173 cache_key = f"{item.instance_image_path}_{item.class_image_path}"
@@ -169,30 +180,18 @@ class CSVDataset(Dataset):
169 example["prompts"] = item.prompt 180 example["prompts"] = item.prompt
170 example["nprompts"] = item.nprompt 181 example["nprompts"] = item.nprompt
171 182
172 if item.instance_image_path in self.image_cache: 183 example["instance_images"] = self.get_image(item.instance_image_path)
173 instance_image = self.image_cache[item.instance_image_path]
174 else:
175 instance_image = Image.open(item.instance_image_path)
176 if not instance_image.mode == "RGB":
177 instance_image = instance_image.convert("RGB")
178 self.image_cache[item.instance_image_path] = instance_image
179
180 example["instance_images"] = instance_image
181 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( 184 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
182 item.prompt.format(self.instance_identifier) 185 item.prompt.format(self.instance_identifier)
183 ) 186 )
184 187
185 if self.num_class_images != 0: 188 if self.num_class_images != 0:
186 class_image = Image.open(item.class_image_path) 189 example["class_images"] = self.get_image(item.class_image_path)
187 if not class_image.mode == "RGB":
188 class_image = class_image.convert("RGB")
189
190 example["class_images"] = class_image
191 example["class_prompt_ids"] = self.prompt_processor.get_input_ids( 190 example["class_prompt_ids"] = self.prompt_processor.get_input_ids(
192 item.nprompt.format(self.class_identifier) 191 item.nprompt.format(self.class_identifier)
193 ) 192 )
194 193
195 self.cache[item.instance_image_path] = example 194 self.cache[cache_key] = example
196 return example 195 return example
197 196
198 def __getitem__(self, i): 197 def __getitem__(self, i):
@@ -204,7 +203,7 @@ class CSVDataset(Dataset):
204 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) 203 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
205 example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] 204 example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]
206 205
207 if self.class_identifier is not None: 206 if self.num_class_images != 0:
208 example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) 207 example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
209 example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] 208 example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"]
210 209
diff --git a/dreambooth.py b/dreambooth.py
index 72c56cd..5c26f12 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -354,6 +354,8 @@ class Checkpointer:
354 text_encoder, 354 text_encoder,
355 output_dir: Path, 355 output_dir: Path,
356 instance_identifier, 356 instance_identifier,
357 placeholder_token,
358 placeholder_token_id,
357 sample_image_size, 359 sample_image_size,
358 sample_batches, 360 sample_batches,
359 sample_batch_size, 361 sample_batch_size,
@@ -368,12 +370,36 @@ class Checkpointer:
368 self.text_encoder = text_encoder 370 self.text_encoder = text_encoder
369 self.output_dir = output_dir 371 self.output_dir = output_dir
370 self.instance_identifier = instance_identifier 372 self.instance_identifier = instance_identifier
373 self.placeholder_token = placeholder_token
374 self.placeholder_token_id = placeholder_token_id
371 self.sample_image_size = sample_image_size 375 self.sample_image_size = sample_image_size
372 self.seed = seed or torch.random.seed() 376 self.seed = seed or torch.random.seed()
373 self.sample_batches = sample_batches 377 self.sample_batches = sample_batches
374 self.sample_batch_size = sample_batch_size 378 self.sample_batch_size = sample_batch_size
375 379
376 @torch.no_grad() 380 @torch.no_grad()
381 def save_embedding(self, step, postfix):
382 if self.placeholder_token_id is None:
383 return
384
385 print("Saving checkpoint for step %d..." % step)
386
387 checkpoints_path = self.output_dir.joinpath("checkpoints")
388 checkpoints_path.mkdir(parents=True, exist_ok=True)
389
390 unwrapped = self.accelerator.unwrap_model(self.text_encoder)
391
392 # Save a checkpoint
393 learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
394 learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
395
396 filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
397 torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
398
399 del unwrapped
400 del learned_embeds
401
402 @torch.no_grad()
377 def save_model(self): 403 def save_model(self):
378 print("Saving model...") 404 print("Saving model...")
379 405
@@ -567,6 +593,8 @@ def main():
567 text_encoder.text_model.final_layer_norm.parameters(), 593 text_encoder.text_model.final_layer_norm.parameters(),
568 text_encoder.text_model.embeddings.position_embedding.parameters(), 594 text_encoder.text_model.embeddings.position_embedding.parameters(),
569 )) 595 ))
596 else:
597 placeholder_token_id = None
570 598
571 prompt_processor = PromptProcessor(tokenizer, text_encoder) 599 prompt_processor = PromptProcessor(tokenizer, text_encoder)
572 600
@@ -785,6 +813,8 @@ def main():
785 text_encoder=text_encoder, 813 text_encoder=text_encoder,
786 output_dir=basepath, 814 output_dir=basepath,
787 instance_identifier=instance_identifier, 815 instance_identifier=instance_identifier,
816 placeholder_token=args.placeholder_token,
817 placeholder_token_id=placeholder_token_id,
788 sample_image_size=args.sample_image_size, 818 sample_image_size=args.sample_image_size,
789 sample_batch_size=args.sample_batch_size, 819 sample_batch_size=args.sample_batch_size,
790 sample_batches=args.sample_batches, 820 sample_batches=args.sample_batches,
@@ -902,6 +932,7 @@ def main():
902 global_step += 1 932 global_step += 1
903 933
904 if global_step % args.sample_frequency == 0: 934 if global_step % args.sample_frequency == 0:
935 checkpointer.save_embedding(global_step, "training")
905 sample_checkpoint = True 936 sample_checkpoint = True
906 937
907 logs = { 938 logs = {
@@ -968,6 +999,7 @@ def main():
968 if min_val_loss > val_loss: 999 if min_val_loss > val_loss:
969 accelerator.print( 1000 accelerator.print(
970 f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") 1001 f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
1002 checkpointer.save_embedding(global_step, "milestone")
971 min_val_loss = val_loss 1003 min_val_loss = val_loss
972 1004
973 if sample_checkpoint and accelerator.is_main_process: 1005 if sample_checkpoint and accelerator.is_main_process:
@@ -978,6 +1010,7 @@ def main():
978 # Create the pipeline using using the trained modules and save it. 1010 # Create the pipeline using using the trained modules and save it.
979 if accelerator.is_main_process: 1011 if accelerator.is_main_process:
980 print("Finished! Saving final checkpoint and resume state.") 1012 print("Finished! Saving final checkpoint and resume state.")
1013 checkpointer.save_embedding(global_step, "end")
981 checkpointer.save_model() 1014 checkpointer.save_model()
982 1015
983 accelerator.end_training() 1016 accelerator.end_training()
@@ -985,6 +1018,7 @@ def main():
985 except KeyboardInterrupt: 1018 except KeyboardInterrupt:
986 if accelerator.is_main_process: 1019 if accelerator.is_main_process:
987 print("Interrupted, saving checkpoint and resume state...") 1020 print("Interrupted, saving checkpoint and resume state...")
1021 checkpointer.save_embedding(global_step, "end")
988 checkpointer.save_model() 1022 checkpointer.save_model()
989 accelerator.end_training() 1023 accelerator.end_training()
990 quit() 1024 quit()
diff --git a/models/attention/control.py b/models/attention/control.py
new file mode 100644
index 0000000..248bd9f
--- /dev/null
+++ b/models/attention/control.py
@@ -0,0 +1,168 @@
1import torch
2import abc
3
4
5class AttentionControl(abc.ABC):
6 def step_callback(self, x_t):
7 return x_t
8
9 def between_steps(self):
10 return
11
12 @property
13 def num_uncond_att_layers(self):
14 return self.num_att_layers if LOW_RESOURCE else 0
15
16 @abc.abstractmethod
17 def forward(self, attn, is_cross: bool, place_in_unet: str):
18 raise NotImplementedError
19
20 def __call__(self, attn, is_cross: bool, place_in_unet: str):
21 if self.cur_att_layer >= self.num_uncond_att_layers:
22 if LOW_RESOURCE:
23 attn = self.forward(attn, is_cross, place_in_unet)
24 else:
25 h = attn.shape[0]
26 attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
27 self.cur_att_layer += 1
28 if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
29 self.cur_att_layer = 0
30 self.cur_step += 1
31 self.between_steps()
32 return attn
33
34 def reset(self):
35 self.cur_step = 0
36 self.cur_att_layer = 0
37
38 def __init__(self):
39 self.cur_step = 0
40 self.num_att_layers = -1
41 self.cur_att_layer = 0
42
43
44class EmptyControl(AttentionControl):
45 def forward(self, attn, is_cross: bool, place_in_unet: str):
46 return attn
47
48
49class AttentionStore(AttentionControl):
50 @staticmethod
51 def get_empty_store():
52 return {"down_cross": [], "mid_cross": [], "up_cross": [],
53 "down_self": [], "mid_self": [], "up_self": []}
54
55 def forward(self, attn, is_cross: bool, place_in_unet: str):
56 key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
57 if attn.shape[1] <= 32 ** 2: # avoid memory overhead
58 self.step_store[key].append(attn)
59 return attn
60
61 def between_steps(self):
62 if len(self.attention_store) == 0:
63 self.attention_store = self.step_store
64 else:
65 for key in self.attention_store:
66 for i in range(len(self.attention_store[key])):
67 self.attention_store[key][i] += self.step_store[key][i]
68 self.step_store = self.get_empty_store()
69
70 def get_average_attention(self):
71 average_attention = {key: [item / self.cur_step for item in self.attention_store[key]]
72 for key in self.attention_store}
73 return average_attention
74
75 def reset(self):
76 super(AttentionStore, self).reset()
77 self.step_store = self.get_empty_store()
78 self.attention_store = {}
79
80 def __init__(self):
81 super(AttentionStore, self).__init__()
82 self.step_store = self.get_empty_store()
83 self.attention_store = {}
84
85
86class AttentionControlEdit(AttentionStore, abc.ABC):
87 def step_callback(self, x_t):
88 if self.local_blend is not None:
89 x_t = self.local_blend(x_t, self.attention_store)
90 return x_t
91
92 def replace_self_attention(self, attn_base, att_replace):
93 if att_replace.shape[2] <= 16 ** 2:
94 return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
95 else:
96 return att_replace
97
98 @abc.abstractmethod
99 def replace_cross_attention(self, attn_base, att_replace):
100 raise NotImplementedError
101
102 def forward(self, attn, is_cross: bool, place_in_unet: str):
103 super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
104 if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
105 h = attn.shape[0] // (self.batch_size)
106 attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
107 attn_base, attn_repalce = attn[0], attn[1:]
108 if is_cross:
109 alpha_words = self.cross_replace_alpha[self.cur_step]
110 attn_repalce_new = self.replace_cross_attention(
111 attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce
112 attn[1:] = attn_repalce_new
113 else:
114 attn[1:] = self.replace_self_attention(attn_base, attn_repalce)
115 attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
116 return attn
117
118 def __init__(self, prompts, num_steps: int,
119 cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
120 self_replace_steps: Union[float, Tuple[float, float]],
121 local_blend: Optional[LocalBlend]):
122 super(AttentionControlEdit, self).__init__()
123 self.batch_size = len(prompts)
124 self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(
125 prompts, num_steps, cross_replace_steps, tokenizer).to(device)
126 if type(self_replace_steps) is float:
127 self_replace_steps = 0, self_replace_steps
128 self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
129 self.local_blend = local_blend
130
131
132class AttentionReplace(AttentionControlEdit):
133 def replace_cross_attention(self, attn_base, att_replace):
134 return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
135
136 def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
137 local_blend: Optional[LocalBlend] = None):
138 super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
139 self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device)
140
141
142class AttentionRefine(AttentionControlEdit):
143 def replace_cross_attention(self, attn_base, att_replace):
144 attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
145 attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
146 return attn_replace
147
148 def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
149 local_blend: Optional[LocalBlend] = None):
150 super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
151 self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
152 self.mapper, alphas = self.mapper.to(device), alphas.to(device)
153 self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
154
155
156class AttentionReweight(AttentionControlEdit):
157 def replace_cross_attention(self, attn_base, att_replace):
158 if self.prev_controller is not None:
159 attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)
160 attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
161 return attn_replace
162
163 def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer,
164 local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None):
165 super(AttentionReweight, self).__init__(prompts, num_steps,
166 cross_replace_steps, self_replace_steps, local_blend)
167 self.equalizer = equalizer.to(device)
168 self.prev_controller = controller
diff --git a/models/attention/hook.py b/models/attention/hook.py
new file mode 100644
index 0000000..903de02
--- /dev/null
+++ b/models/attention/hook.py
@@ -0,0 +1,62 @@
1import torch
2
3
4try:
5 import xformers.ops
6 xformers._is_functorch_available = True
7 MEM_EFFICIENT_ATTN = True
8except ImportError:
9 print("[!] Not using xformers memory efficient attention.")
10 MEM_EFFICIENT_ATTN = False
11
12
13def register_attention_control(model, controller):
14 def ca_forward(self, place_in_unet):
15 def forward(x, context=None, mask=None):
16 batch_size, sequence_length, dim = x.shape
17 h = self.heads
18 q = self.to_q(x)
19 is_cross = context is not None
20 context = context if is_cross else x
21 k = self.to_k(context)
22 v = self.to_v(context)
23 q = self.reshape_heads_to_batch_dim(q)
24 k = self.reshape_heads_to_batch_dim(k)
25 v = self.reshape_heads_to_batch_dim(v)
26
27 sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
28
29 if mask is not None:
30 mask = mask.reshape(batch_size, -1)
31 max_neg_value = -torch.finfo(sim.dtype).max
32 mask = mask[:, None, :].repeat(h, 1, 1)
33 sim.masked_fill_(~mask, max_neg_value)
34
35 # attention, what we cannot get enough of
36 attn = sim.softmax(dim=-1)
37 attn = controller(attn, is_cross, place_in_unet)
38 out = torch.einsum("b i j, b j d -> b i d", attn, v)
39 out = self.reshape_batch_dim_to_heads(out)
40 return self.to_out(out)
41
42 return forward
43
44 def register_recr(net_, count, place_in_unet):
45 if net_.__class__.__name__ == 'CrossAttention':
46 net_.forward = ca_forward(net_, place_in_unet)
47 return count + 1
48 elif hasattr(net_, 'children'):
49 for net__ in net_.children():
50 count = register_recr(net__, count, place_in_unet)
51 return count
52
53 cross_att_count = 0
54 sub_nets = model.unet.named_children()
55 for net in sub_nets:
56 if "down" in net[0]:
57 cross_att_count += register_recr(net[1], 0, "down")
58 elif "up" in net[0]:
59 cross_att_count += register_recr(net[1], 0, "up")
60 elif "mid" in net[0]:
61 cross_att_count += register_recr(net[1], 0, "mid")
62 controller.num_att_layers = cross_att_count
diff --git a/models/attention/structured.py b/models/attention/structured.py
new file mode 100644
index 0000000..24d889f
--- /dev/null
+++ b/models/attention/structured.py
@@ -0,0 +1,132 @@
1import torch
2
3from .control import AttentionControl
4
5
6class StructuredAttentionControl(AttentionControl):
7 def forward(self, attn, is_cross: bool, place_in_unet: str):
8 return attn
9
10 def forward(self, x, context=None, mask=None):
11 h = self.heads
12
13 q = self.to_q(x)
14
15 if isinstance(context, list):
16 if self.struct_attn:
17 out = self.struct_qkv(q, context, mask)
18 else:
19 context = torch.cat([context[0], context[1]['k'][0]], dim=0) # use key tensor for context
20 out = self.normal_qkv(q, context, mask)
21 else:
22 context = default(context, x)
23 out = self.normal_qkv(q, context, mask)
24
25 return self.to_out(out)
26
27 def struct_qkv(self, q, context, mask):
28 """
29 context: list of [uc, list of conditional context]
30 """
31 uc_context = context[0]
32 context_k, context_v = context[1]['k'], context[1]['v']
33
34 if isinstance(context_k, list) and isinstance(context_v, list):
35 out = self.multi_qkv(q, uc_context, context_k, context_v, mask)
36 elif isinstance(context_k, torch.Tensor) and isinstance(context_v, torch.Tensor):
37 out = self.heterogeous_qkv(q, uc_context, context_k, context_v, mask)
38 else:
39 raise NotImplementedError
40
41 return out
42
43 def multi_qkv(self, q, uc_context, context_k, context_v, mask):
44 h = self.heads
45
46 assert uc_context.size(0) == context_k[0].size(0) == context_v[0].size(0)
47 true_bs = uc_context.size(0) * h
48
49 k_uc, v_uc = self.get_kv(uc_context)
50 k_c = [self.to_k(c_k) for c_k in context_k]
51 v_c = [self.to_v(c_v) for c_v in context_v]
52
53 q = rearrange(q, 'b n (h d) -> (b h) n d', h=h)
54
55 k_uc = rearrange(k_uc, 'b n (h d) -> (b h) n d', h=h)
56 v_uc = rearrange(v_uc, 'b n (h d) -> (b h) n d', h=h)
57
58 k_c = [rearrange(k, 'b n (h d) -> (b h) n d', h=h) for k in k_c] # NOTE: modification point
59 v_c = [rearrange(v, 'b n (h d) -> (b h) n d', h=h) for v in v_c]
60
61 # get composition
62 sim_uc = einsum('b i d, b j d -> b i j', q[:true_bs], k_uc) * self.scale
63 sim_c = [einsum('b i d, b j d -> b i j', q[true_bs:], k) * self.scale for k in k_c]
64
65 attn_uc = sim_uc.softmax(dim=-1)
66 attn_c = [sim.softmax(dim=-1) for sim in sim_c]
67
68 # get uc output
69 out_uc = einsum('b i j, b j d -> b i d', attn_uc, v_uc)
70
71 # get c output
72 if len(v_c) == 1:
73 out_c_collect = []
74 for attn in attn_c:
75 for v in v_c:
76 out_c_collect.append(einsum('b i j, b j d -> b i d', attn, v))
77 out_c = sum(out_c_collect) / len(out_c_collect)
78 else:
79 out_c = sum([einsum('b i j, b j d -> b i d', attn, v) for attn, v in zip(attn_c, v_c)]) / len(v_c)
80
81 out = torch.cat([out_uc, out_c], dim=0)
82 out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
83
84 return out
85
86 def normal_qkv(self, q, context, mask):
87 h = self.heads
88 k = self.to_k(context)
89 v = self.to_v(context)
90
91 q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
92
93 sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
94
95 if exists(mask):
96 mask = rearrange(mask, 'b ... -> b (...)')
97 max_neg_value = -torch.finfo(sim.dtype).max
98 mask = repeat(mask, 'b j -> (b h) () j', h=h)
99 sim.masked_fill_(~mask, max_neg_value)
100
101 # attention, what we cannot get enough of
102 attn = sim.softmax(dim=-1)
103
104 out = einsum('b i j, b j d -> b i d', attn, v)
105 out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
106
107 return out
108
109 def heterogeous_qkv(self, q, uc_context, context_k, context_v, mask):
110 h = self.heads
111 k = self.to_k(torch.cat([uc_context, context_k], dim=0))
112 v = self.to_v(torch.cat([uc_context, context_v], dim=0))
113
114 q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
115
116 sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
117
118 if exists(mask):
119 mask = rearrange(mask, 'b ... -> b (...)')
120 max_neg_value = -torch.finfo(sim.dtype).max
121 mask = repeat(mask, 'b j -> (b h) () j', h=h)
122 sim.masked_fill_(~mask, max_neg_value)
123
124 # attention, what we cannot get enough of
125 attn = sim.softmax(dim=-1)
126
127 out = einsum('b i j, b j d -> b i d', attn, v)
128 out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
129 return out
130
131 def get_kv(self, context):
132 return self.to_k(context), self.to_v(context)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 3da0169..e90528d 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -159,6 +159,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
159 159
160 batch_size = len(prompt) 160 batch_size = len(prompt)
161 161
162 if negative_prompt is None:
163 negative_prompt = ""
164
162 if isinstance(negative_prompt, str): 165 if isinstance(negative_prompt, str):
163 negative_prompt = [negative_prompt] * batch_size 166 negative_prompt = [negative_prompt] * batch_size
164 167