Diffstat (limited to 'pipelines')
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  262
1 files changed, 188 insertions, 74 deletions

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index aa446ec..16b8456 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -21,7 +21,9 @@ from diffusers import ( | |||
21 | LMSDiscreteScheduler, | 21 | LMSDiscreteScheduler, |
22 | PNDMScheduler, | 22 | PNDMScheduler, |
23 | ) | 23 | ) |
24 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | 24 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( |
25 | StableDiffusionPipelineOutput, | ||
26 | ) | ||
25 | from diffusers.utils import logging, randn_tensor | 27 | from diffusers.utils import logging, randn_tensor |
26 | from transformers import CLIPTextModel, CLIPTokenizer | 28 | from transformers import CLIPTextModel, CLIPTokenizer |
27 | 29 | ||
@@ -62,13 +64,35 @@ def gaussian_blur_2d(img, kernel_size, sigma): | |||
62 | return img | 64 | return img |
63 | 65 | ||
64 | 66 | ||
67 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): | ||
68 | """ | ||
69 | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and | ||
70 | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 | ||
71 | """ | ||
72 | std_text = noise_pred_text.std( | ||
73 | dim=list(range(1, noise_pred_text.ndim)), keepdim=True | ||
74 | ) | ||
75 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) | ||
76 | # rescale the results from guidance (fixes overexposure) | ||
77 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg) | ||
78 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images | ||
79 | noise_cfg = ( | ||
80 | guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg | ||
81 | ) | ||
82 | return noise_cfg | ||
83 | |||
84 | |||
65 | class CrossAttnStoreProcessor: | 85 | class CrossAttnStoreProcessor: |
66 | def __init__(self): | 86 | def __init__(self): |
67 | self.attention_probs = None | 87 | self.attention_probs = None |
68 | 88 | ||
69 | def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): | 89 | def __call__( |
90 | self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None | ||
91 | ): | ||
70 | batch_size, sequence_length, _ = hidden_states.shape | 92 | batch_size, sequence_length, _ = hidden_states.shape |
71 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) | 93 | attention_mask = attn.prepare_attention_mask( |
94 | attention_mask, sequence_length, batch_size | ||
95 | ) | ||
72 | query = attn.to_q(hidden_states) | 96 | query = attn.to_q(hidden_states) |
73 | 97 | ||
74 | if encoder_hidden_states is None: | 98 | if encoder_hidden_states is None: |
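
The new rescale_noise_cfg helper added in this hunk implements the guidance-rescale trick from "Common Diffusion Noise Schedules and Sample Steps are Flawed" (Section 3.4): the combined classifier-free-guidance prediction is rescaled so its per-sample standard deviation matches the text-conditioned prediction, then blended back in by guidance_rescale. A minimal standalone sketch on dummy tensors (shapes and the guidance scale of 7.5 are illustrative only):

    import torch

    def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
        # per-sample std over all non-batch dimensions
        std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
        std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
        rescaled = noise_cfg * (std_text / std_cfg)  # counteracts overexposure
        return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg

    noise_pred_uncond = torch.randn(2, 4, 64, 64)   # dummy latent-shaped predictions
    noise_pred_text = torch.randn(2, 4, 64, 64)
    noise_cfg = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)
    out = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)
    print(noise_cfg.std().item(), out.std().item())  # std shrinks toward the text prediction's std
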
@@ -113,7 +137,10 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
113 | ): | 137 | ): |
114 | super().__init__() | 138 | super().__init__() |
115 | 139 | ||
116 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: | 140 | if ( |
141 | hasattr(scheduler.config, "steps_offset") | ||
142 | and scheduler.config.steps_offset != 1 | ||
143 | ): | ||
117 | warnings.warn( | 144 | warnings.warn( |
118 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" | 145 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" |
119 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " | 146 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " |
@@ -179,7 +206,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
179 | 206 | ||
180 | device = torch.device("cuda") | 207 | device = torch.device("cuda") |
181 | 208 | ||
182 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: | 209 | for cpu_offloaded_model in [ |
210 | self.unet, | ||
211 | self.text_encoder, | ||
212 | self.vae, | ||
213 | self.safety_checker, | ||
214 | ]: | ||
183 | if cpu_offloaded_model is not None: | 215 | if cpu_offloaded_model is not None: |
184 | cpu_offload(cpu_offloaded_model, device) | 216 | cpu_offload(cpu_offloaded_model, device) |
185 | 217 | ||
@@ -223,35 +255,47 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
223 | width: int, | 255 | width: int, |
224 | height: int, | 256 | height: int, |
225 | strength: float, | 257 | strength: float, |
226 | callback_steps: Optional[int] | 258 | callback_steps: Optional[int], |
227 | ): | 259 | ): |
228 | if isinstance(prompt, str) or (isinstance(prompt, list) and isinstance(prompt[0], int)): | 260 | if isinstance(prompt, str) or ( |
261 | isinstance(prompt, list) and isinstance(prompt[0], int) | ||
262 | ): | ||
229 | prompt = [prompt] | 263 | prompt = [prompt] |
230 | 264 | ||
231 | if negative_prompt is None: | 265 | if negative_prompt is None: |
232 | negative_prompt = "" | 266 | negative_prompt = "" |
233 | 267 | ||
234 | if isinstance(negative_prompt, str) or (isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int)): | 268 | if isinstance(negative_prompt, str) or ( |
269 | isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int) | ||
270 | ): | ||
235 | negative_prompt = [negative_prompt] * len(prompt) | 271 | negative_prompt = [negative_prompt] * len(prompt) |
236 | 272 | ||
237 | if not isinstance(prompt, list): | 273 | if not isinstance(prompt, list): |
238 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") | 274 | raise ValueError( |
275 | f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" | ||
276 | ) | ||
239 | 277 | ||
240 | if not isinstance(negative_prompt, list): | 278 | if not isinstance(negative_prompt, list): |
241 | raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") | 279 | raise ValueError( |
280 | f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}" | ||
281 | ) | ||
242 | 282 | ||
243 | if len(negative_prompt) != len(prompt): | 283 | if len(negative_prompt) != len(prompt): |
244 | raise ValueError( | 284 | raise ValueError( |
245 | f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}") | 285 | f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}" |
286 | ) | ||
246 | 287 | ||
247 | if strength < 0 or strength > 1: | 288 | if strength < 0 or strength > 1: |
248 | raise ValueError(f"`strength` should be in [0.0, 1.0] but is {strength}") | 289 | raise ValueError(f"`strength` should be in [0.0, 1.0] but is {strength}") |
249 | 290 | ||
250 | if height % 8 != 0 or width % 8 != 0: | 291 | if height % 8 != 0 or width % 8 != 0: |
251 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") | 292 | raise ValueError( |
293 | f"`height` and `width` have to be divisible by 8 but are {height} and {width}." | ||
294 | ) | ||
252 | 295 | ||
253 | if (callback_steps is None) or ( | 296 | if (callback_steps is None) or ( |
254 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) | 297 | callback_steps is not None |
298 | and (not isinstance(callback_steps, int) or callback_steps <= 0) | ||
255 | ): | 299 | ): |
256 | raise ValueError( | 300 | raise ValueError( |
257 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" | 301 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" |
@@ -266,7 +310,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
266 | negative_prompt: Union[List[str], List[List[int]]], | 310 | negative_prompt: Union[List[str], List[List[int]]], |
267 | num_images_per_prompt: int, | 311 | num_images_per_prompt: int, |
268 | do_classifier_free_guidance: bool, | 312 | do_classifier_free_guidance: bool, |
269 | device | 313 | device, |
270 | ): | 314 | ): |
271 | if isinstance(prompt[0], str): | 315 | if isinstance(prompt[0], str): |
272 | text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids | 316 | text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids |
@@ -277,7 +321,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
277 | 321 | ||
278 | if do_classifier_free_guidance: | 322 | if do_classifier_free_guidance: |
279 | if isinstance(prompt[0], str): | 323 | if isinstance(prompt[0], str): |
280 | unconditional_input_ids = self.tokenizer(negative_prompt, padding="do_not_pad").input_ids | 324 | unconditional_input_ids = self.tokenizer( |
325 | negative_prompt, padding="do_not_pad" | ||
326 | ).input_ids | ||
281 | else: | 327 | else: |
282 | unconditional_input_ids = negative_prompt | 328 | unconditional_input_ids = negative_prompt |
283 | unconditional_input_ids *= num_images_per_prompt | 329 | unconditional_input_ids *= num_images_per_prompt |
@@ -286,12 +332,17 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
286 | text_inputs = unify_input_ids(self.tokenizer, text_input_ids) | 332 | text_inputs = unify_input_ids(self.tokenizer, text_input_ids) |
287 | text_input_ids = text_inputs.input_ids | 333 | text_input_ids = text_inputs.input_ids |
288 | 334 | ||
289 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: | 335 | if ( |
336 | hasattr(self.text_encoder.config, "use_attention_mask") | ||
337 | and self.text_encoder.config.use_attention_mask | ||
338 | ): | ||
290 | attention_mask = text_inputs.attention_mask.to(device) | 339 | attention_mask = text_inputs.attention_mask.to(device) |
291 | else: | 340 | else: |
292 | attention_mask = None | 341 | attention_mask = None |
293 | 342 | ||
294 | prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids.to(device), attention_mask) | 343 | prompt_embeds = get_extended_embeddings( |
344 | self.text_encoder, text_input_ids.to(device), attention_mask | ||
345 | ) | ||
295 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) | 346 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) |
296 | 347 | ||
297 | return prompt_embeds | 348 | return prompt_embeds |
@@ -301,25 +352,21 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
301 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) | 352 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) |
302 | 353 | ||
303 | t_start = max(num_inference_steps - init_timestep, 0) | 354 | t_start = max(num_inference_steps - init_timestep, 0) |
304 | timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:] | 355 | timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] |
305 | 356 | ||
306 | timesteps = timesteps.to(device) | 357 | timesteps = timesteps.to(device) |
307 | 358 | ||
308 | return timesteps, num_inference_steps - t_start | 359 | return timesteps, num_inference_steps - t_start |
309 | 360 | ||
310 | def prepare_brightness_offset(self, batch_size, height, width, dtype, device, generator=None): | 361 | def prepare_latents_from_image( |
311 | offset_image = perlin_noise( | 362 | self, |
312 | (batch_size, 1, width, height), | 363 | init_image, |
313 | res=1, | 364 | timestep, |
314 | generator=generator, | 365 | batch_size, |
315 | dtype=dtype, | 366 | dtype, |
316 | device=device | 367 | device, |
317 | ) | 368 | generator=None, |
318 | offset_latents = self.vae.encode(offset_image).latent_dist.sample(generator=generator) | 369 | ): |
319 | offset_latents = self.vae.config.scaling_factor * offset_latents | ||
320 | return offset_latents | ||
321 | |||
322 | def prepare_latents_from_image(self, init_image, timestep, batch_size, brightness_offset, dtype, device, generator=None): | ||
323 | init_image = init_image.to(device=device, dtype=dtype) | 370 | init_image = init_image.to(device=device, dtype=dtype) |
324 | latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) | 371 | latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) |
325 | latents = self.vae.config.scaling_factor * latents | 372 | latents = self.vae.config.scaling_factor * latents |
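
get_timesteps keeps only the tail of the scheduler's timestep list: with N inference steps and strength s, int(N * s) steps are actually run, so strength 1.0 denoises from scratch and smaller values preserve more of the init image. A self-contained check of the arithmetic (the scheduler choice and step count are arbitrary for the example):

    from diffusers import DDIMScheduler

    scheduler = DDIMScheduler(num_train_timesteps=1000)
    num_inference_steps, strength = 50, 0.3

    scheduler.set_timesteps(num_inference_steps)
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 15
    t_start = max(num_inference_steps - init_timestep, 0)                          # 35
    timesteps = scheduler.timesteps[t_start * scheduler.order :]
    print(len(timesteps))  # 15 -- only the last 15 timesteps are executed
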
@@ -333,20 +380,32 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
333 | latents = torch.cat([latents] * batch_multiplier, dim=0) | 380 | latents = torch.cat([latents] * batch_multiplier, dim=0) |
334 | 381 | ||
335 | # add noise to latents using the timesteps | 382 | # add noise to latents using the timesteps |
336 | noise = torch.randn(latents.shape, generator=generator, device=device, dtype=dtype) | 383 | noise = torch.randn( |
337 | 384 | latents.shape, generator=generator, device=device, dtype=dtype | |
338 | if brightness_offset != 0: | 385 | ) |
339 | noise += brightness_offset * self.prepare_brightness_offset( | ||
340 | batch_size, init_image.shape[3], init_image.shape[2], dtype, device, generator | ||
341 | ) | ||
342 | 386 | ||
343 | # get latents | 387 | # get latents |
344 | latents = self.scheduler.add_noise(latents, noise, timestep) | 388 | latents = self.scheduler.add_noise(latents, noise, timestep) |
345 | 389 | ||
346 | return latents | 390 | return latents |
347 | 391 | ||
348 | def prepare_latents(self, batch_size, num_channels_latents, height, width, brightness_offset, dtype, device, generator, latents=None): | 392 | def prepare_latents( |
349 | shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) | 393 | self, |
394 | batch_size, | ||
395 | num_channels_latents, | ||
396 | height, | ||
397 | width, | ||
398 | dtype, | ||
399 | device, | ||
400 | generator, | ||
401 | latents=None, | ||
402 | ): | ||
403 | shape = ( | ||
404 | batch_size, | ||
405 | num_channels_latents, | ||
406 | height // self.vae_scale_factor, | ||
407 | width // self.vae_scale_factor, | ||
408 | ) | ||
350 | if isinstance(generator, list) and len(generator) != batch_size: | 409 | if isinstance(generator, list) and len(generator) != batch_size: |
351 | raise ValueError( | 410 | raise ValueError( |
352 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" | 411 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" |
@@ -354,15 +413,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
354 | ) | 413 | ) |
355 | 414 | ||
356 | if latents is None: | 415 | if latents is None: |
357 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) | 416 | latents = randn_tensor( |
417 | shape, generator=generator, device=device, dtype=dtype | ||
418 | ) | ||
358 | else: | 419 | else: |
359 | latents = latents.to(device) | 420 | latents = latents.to(device) |
360 | 421 | ||
361 | if brightness_offset != 0: | ||
362 | latents += brightness_offset * self.prepare_brightness_offset( | ||
363 | batch_size, height, width, dtype, device, generator | ||
364 | ) | ||
365 | |||
366 | # scale the initial noise by the standard deviation required by the scheduler | 422 | # scale the initial noise by the standard deviation required by the scheduler |
367 | latents = latents * self.scheduler.init_noise_sigma | 423 | latents = latents * self.scheduler.init_noise_sigma |
368 | return latents | 424 | return latents |
@@ -373,13 +429,17 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
373 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 | 429 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 |
374 | # and should be between [0, 1] | 430 | # and should be between [0, 1] |
375 | 431 | ||
376 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) | 432 | accepts_eta = "eta" in set( |
433 | inspect.signature(self.scheduler.step).parameters.keys() | ||
434 | ) | ||
377 | extra_step_kwargs = {} | 435 | extra_step_kwargs = {} |
378 | if accepts_eta: | 436 | if accepts_eta: |
379 | extra_step_kwargs["eta"] = eta | 437 | extra_step_kwargs["eta"] = eta |
380 | 438 | ||
381 | # check if the scheduler accepts generator | 439 | # check if the scheduler accepts generator |
382 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) | 440 | accepts_generator = "generator" in set( |
441 | inspect.signature(self.scheduler.step).parameters.keys() | ||
442 | ) | ||
383 | if accepts_generator: | 443 | if accepts_generator: |
384 | extra_step_kwargs["generator"] = generator | 444 | extra_step_kwargs["generator"] = generator |
385 | return extra_step_kwargs | 445 | return extra_step_kwargs |
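
prepare_extra_step_kwargs builds the scheduler.step kwargs by introspection because not every scheduler accepts eta or generator. A quick demonstration of that signature check against two stock diffusers schedulers (the scheduler choices are illustrative):

    import inspect
    from diffusers import DDIMScheduler, EulerDiscreteScheduler

    for cls in (DDIMScheduler, EulerDiscreteScheduler):
        params = set(inspect.signature(cls.step).parameters.keys())
        print(cls.__name__, "eta" in params, "generator" in params)
    # DDIM accepts both; Euler only accepts a generator, so eta must not be passed to it
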
@@ -396,7 +456,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
396 | def __call__( | 456 | def __call__( |
397 | self, | 457 | self, |
398 | prompt: Union[str, List[str], List[int], List[List[int]]], | 458 | prompt: Union[str, List[str], List[int], List[List[int]]], |
399 | negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None, | 459 | negative_prompt: Optional[ |
460 | Union[str, List[str], List[int], List[List[int]]] | ||
461 | ] = None, | ||
400 | num_images_per_prompt: int = 1, | 462 | num_images_per_prompt: int = 1, |
401 | strength: float = 1.0, | 463 | strength: float = 1.0, |
402 | height: Optional[int] = None, | 464 | height: Optional[int] = None, |
@@ -407,12 +469,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
407 | eta: float = 0.0, | 469 | eta: float = 0.0, |
408 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | 470 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
409 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, | 471 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, |
410 | brightness_offset: Union[float, torch.FloatTensor] = 0, | ||
411 | output_type: str = "pil", | 472 | output_type: str = "pil", |
412 | return_dict: bool = True, | 473 | return_dict: bool = True, |
413 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | 474 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
414 | callback_steps: int = 1, | 475 | callback_steps: int = 1, |
415 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, | 476 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
477 | guidance_rescale: float = 0.0, | ||
416 | ): | 478 | ): |
417 | r""" | 479 | r""" |
418 | Function invoked when calling the pipeline for generation. | 480 | Function invoked when calling the pipeline for generation. |
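
In the __call__ signature, brightness_offset is gone and guidance_rescale is new; 0.0 reproduces the previous behaviour and the paper suggests values around 0.5-0.7. A hedged usage sketch; the model id, custom_pipeline path and prompt are illustrative assumptions, not taken from this diff:

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        custom_pipeline="pipelines/stable_diffusion/vlpn_stable_diffusion.py",
        torch_dtype=torch.float16,
    ).to("cuda")

    image = pipe(
        prompt="a watercolor painting of a lighthouse at dusk",
        negative_prompt="blurry, low quality",
        num_inference_steps=30,
        guidance_scale=7.5,
        guidance_rescale=0.7,   # new parameter; 0.0 disables the rescaling
    ).images[0]
    image.save("lighthouse.png")
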
@@ -472,7 +534,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
472 | width = width or self.unet.config.sample_size * self.vae_scale_factor | 534 | width = width or self.unet.config.sample_size * self.vae_scale_factor |
473 | 535 | ||
474 | # 1. Check inputs. Raise error if not correct | 536 | # 1. Check inputs. Raise error if not correct |
475 | prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps) | 537 | prompt, negative_prompt = self.check_inputs( |
538 | prompt, negative_prompt, width, height, strength, callback_steps | ||
539 | ) | ||
476 | 540 | ||
477 | # 2. Define call parameters | 541 | # 2. Define call parameters |
478 | batch_size = len(prompt) | 542 | batch_size = len(prompt) |
@@ -488,7 +552,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
488 | negative_prompt, | 552 | negative_prompt, |
489 | num_images_per_prompt, | 553 | num_images_per_prompt, |
490 | do_classifier_free_guidance, | 554 | do_classifier_free_guidance, |
491 | device | 555 | device, |
492 | ) | 556 | ) |
493 | 557 | ||
494 | # 4. Prepare latent variables | 558 | # 4. Prepare latent variables |
@@ -497,7 +561,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
497 | 561 | ||
498 | # 5. Prepare timesteps | 562 | # 5. Prepare timesteps |
499 | self.scheduler.set_timesteps(num_inference_steps, device=device) | 563 | self.scheduler.set_timesteps(num_inference_steps, device=device) |
500 | timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) | 564 | timesteps, num_inference_steps = self.get_timesteps( |
565 | num_inference_steps, strength, device | ||
566 | ) | ||
501 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) | 567 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) |
502 | 568 | ||
503 | # 6. Prepare latent variables | 569 | # 6. Prepare latent variables |
@@ -506,7 +572,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
506 | image, | 572 | image, |
507 | latent_timestep, | 573 | latent_timestep, |
508 | batch_size * num_images_per_prompt, | 574 | batch_size * num_images_per_prompt, |
509 | brightness_offset, | ||
510 | prompt_embeds.dtype, | 575 | prompt_embeds.dtype, |
511 | device, | 576 | device, |
512 | generator, | 577 | generator, |
@@ -517,7 +582,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
517 | num_channels_latents, | 582 | num_channels_latents, |
518 | height, | 583 | height, |
519 | width, | 584 | width, |
520 | brightness_offset, | ||
521 | prompt_embeds.dtype, | 585 | prompt_embeds.dtype, |
522 | device, | 586 | device, |
523 | generator, | 587 | generator, |
@@ -530,14 +594,20 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
530 | # 8. Denoising loop | 594 | # 8. Denoising loop |
531 | if do_self_attention_guidance: | 595 | if do_self_attention_guidance: |
532 | store_processor = CrossAttnStoreProcessor() | 596 | store_processor = CrossAttnStoreProcessor() |
533 | self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor | 597 | self.unet.mid_block.attentions[0].transformer_blocks[ |
598 | 0 | ||
599 | ].attn1.processor = store_processor | ||
534 | 600 | ||
535 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order | 601 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order |
536 | with self.progress_bar(total=num_inference_steps) as progress_bar: | 602 | with self.progress_bar(total=num_inference_steps) as progress_bar: |
537 | for i, t in enumerate(timesteps): | 603 | for i, t in enumerate(timesteps): |
538 | # expand the latents if we are doing classifier free guidance | 604 | # expand the latents if we are doing classifier free guidance |
539 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | 605 | latent_model_input = ( |
540 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | 606 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
607 | ) | ||
608 | latent_model_input = self.scheduler.scale_model_input( | ||
609 | latent_model_input, t | ||
610 | ) | ||
541 | 611 | ||
542 | # predict the noise residual | 612 | # predict the noise residual |
543 | noise_pred = self.unet( | 613 | noise_pred = self.unet( |
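
Inside the denoising loop, classifier-free guidance runs as one batched UNet call: the latents are duplicated, scaled for the current timestep, and the prediction is split back into unconditional and text halves in the next hunk; when self-attention guidance is enabled, a CrossAttnStoreProcessor is also attached to the mid-block's first self-attention so its probabilities can be read back each step. A minimal sketch of the duplicate/chunk/combine pattern on dummy tensors (shapes and the guidance scale are placeholders):

    import torch

    do_classifier_free_guidance = True
    guidance_scale = 7.5

    latents = torch.randn(1, 4, 64, 64)
    latent_model_input = (
        torch.cat([latents] * 2) if do_classifier_free_guidance else latents
    )
    # scheduler.scale_model_input(latent_model_input, t) and the UNet call go here;
    # a random tensor stands in for the UNet's noise prediction
    noise_pred = torch.randn_like(latent_model_input)

    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    print(noise_pred.shape)  # back to the un-duplicated latent shape: (1, 4, 64, 64)
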
@@ -551,7 +621,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
551 | # perform guidance | 621 | # perform guidance |
552 | if do_classifier_free_guidance: | 622 | if do_classifier_free_guidance: |
553 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | 623 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
554 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | 624 | noise_pred = noise_pred_uncond + guidance_scale * ( |
625 | noise_pred_text - noise_pred_uncond | ||
626 | ) | ||
627 | noise_pred = rescale_noise_cfg( | ||
628 | noise_pred, noise_pred_text, guidance_rescale=guidance_rescale | ||
629 | ) | ||
555 | 630 | ||
556 | if do_self_attention_guidance: | 631 | if do_self_attention_guidance: |
557 | # classifier-free guidance produces two chunks of attention map | 632 | # classifier-free guidance produces two chunks of attention map |
@@ -561,15 +636,24 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
561 | # DDIM-like prediction of x0 | 636 | # DDIM-like prediction of x0 |
562 | pred_x0 = self.pred_x0(latents, noise_pred_uncond, t) | 637 | pred_x0 = self.pred_x0(latents, noise_pred_uncond, t) |
563 | # get the stored attention maps | 638 | # get the stored attention maps |
564 | uncond_attn, cond_attn = store_processor.attention_probs.chunk(2) | 639 | uncond_attn, cond_attn = store_processor.attention_probs.chunk( |
640 | 2 | ||
641 | ) | ||
565 | # self-attention-based degrading of latents | 642 | # self-attention-based degrading of latents |
566 | degraded_latents = self.sag_masking( | 643 | degraded_latents = self.sag_masking( |
567 | pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t) | 644 | pred_x0, |
645 | uncond_attn, | ||
646 | t, | ||
647 | self.pred_epsilon(latents, noise_pred_uncond, t), | ||
568 | ) | 648 | ) |
569 | uncond_emb, _ = prompt_embeds.chunk(2) | 649 | uncond_emb, _ = prompt_embeds.chunk(2) |
570 | # forward and give guidance | 650 | # forward and give guidance |
571 | degraded_pred = self.unet( | 651 | degraded_pred = self.unet( |
572 | degraded_latents, t, encoder_hidden_states=uncond_emb, return_dict=False)[0] | 652 | degraded_latents, |
653 | t, | ||
654 | encoder_hidden_states=uncond_emb, | ||
655 | return_dict=False, | ||
656 | )[0] | ||
573 | noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) | 657 | noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) |
574 | else: | 658 | else: |
575 | # DDIM-like prediction of x0 | 659 | # DDIM-like prediction of x0 |
@@ -578,18 +662,29 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
578 | cond_attn = store_processor.attention_probs | 662 | cond_attn = store_processor.attention_probs |
579 | # self-attention-based degrading of latents | 663 | # self-attention-based degrading of latents |
580 | degraded_latents = self.sag_masking( | 664 | degraded_latents = self.sag_masking( |
581 | pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t) | 665 | pred_x0, |
666 | cond_attn, | ||
667 | t, | ||
668 | self.pred_epsilon(latents, noise_pred, t), | ||
582 | ) | 669 | ) |
583 | # forward and give guidance | 670 | # forward and give guidance |
584 | degraded_pred = self.unet( | 671 | degraded_pred = self.unet( |
585 | degraded_latents, t, encoder_hidden_states=prompt_embeds, return_dict=False)[0] | 672 | degraded_latents, |
673 | t, | ||
674 | encoder_hidden_states=prompt_embeds, | ||
675 | return_dict=False, | ||
676 | )[0] | ||
586 | noise_pred += sag_scale * (noise_pred - degraded_pred) | 677 | noise_pred += sag_scale * (noise_pred - degraded_pred) |
587 | 678 | ||
588 | # compute the previous noisy sample x_t -> x_t-1 | 679 | # compute the previous noisy sample x_t -> x_t-1 |
589 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] | 680 | latents = self.scheduler.step( |
681 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False | ||
682 | )[0] | ||
590 | 683 | ||
591 | # call the callback, if provided | 684 | # call the callback, if provided |
592 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): | 685 | if i == len(timesteps) - 1 or ( |
686 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 | ||
687 | ): | ||
593 | progress_bar.update() | 688 | progress_bar.update() |
594 | if callback is not None and i % callback_steps == 0: | 689 | if callback is not None and i % callback_steps == 0: |
595 | callback(i, t, latents) | 690 | callback(i, t, latents) |
@@ -615,7 +710,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
615 | if not return_dict: | 710 | if not return_dict: |
616 | return (image, has_nsfw_concept) | 711 | return (image, has_nsfw_concept) |
617 | 712 | ||
618 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) | 713 | return StableDiffusionPipelineOutput( |
714 | images=image, nsfw_content_detected=has_nsfw_concept | ||
715 | ) | ||
619 | 716 | ||
620 | # Self-Attention-Guided (SAG) Stable Diffusion | 717 | # Self-Attention-Guided (SAG) Stable Diffusion |
621 | 718 | ||
@@ -632,16 +729,23 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
632 | attn_map = attn_map.reshape(b, h, hw1, hw2) | 729 | attn_map = attn_map.reshape(b, h, hw1, hw2) |
633 | attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0 | 730 | attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0 |
634 | attn_mask = ( | 731 | attn_mask = ( |
635 | attn_mask.reshape(b, map_size, map_size).unsqueeze(1).repeat(1, latent_channel, 1, 1).type(attn_map.dtype) | 732 | attn_mask.reshape(b, map_size, map_size) |
733 | .unsqueeze(1) | ||
734 | .repeat(1, latent_channel, 1, 1) | ||
735 | .type(attn_map.dtype) | ||
636 | ) | 736 | ) |
637 | attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w)) | 737 | attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w)) |
638 | 738 | ||
639 | # Blur according to the self-attention mask | 739 | # Blur according to the self-attention mask |
640 | degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0) | 740 | degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0) |
641 | degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask) | 741 | degraded_latents = degraded_latents * attn_mask + original_latents * ( |
742 | 1 - attn_mask | ||
743 | ) | ||
642 | 744 | ||
643 | # Noise it again to match the noise level | 745 | # Noise it again to match the noise level |
644 | degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t) | 746 | degraded_latents = self.scheduler.add_noise( |
747 | degraded_latents, noise=eps, timesteps=t | ||
748 | ) | ||
645 | 749 | ||
646 | return degraded_latents | 750 | return degraded_latents |
647 | 751 | ||
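
sag_masking converts the stored mid-block self-attention into a spatial mask (positions whose head-averaged attention mass exceeds 1.0), upsamples the mask to the latent resolution, blurs the x0 prediction only where the mask is active, and re-noises the result to the current timestep. A toy illustration of the mask construction on random attention probabilities (map_size, latent size and the 1.0 threshold mirror the code; all values are made up):

    import torch
    import torch.nn.functional as F

    b, h, map_size = 1, 8, 16                      # batch, attention heads, mid-block map side
    latent_channel, latent_h, latent_w = 4, 64, 64
    hw = map_size * map_size

    attn_map = torch.rand(b * h, hw, hw)           # stand-in for store_processor.attention_probs
    attn_map = attn_map.reshape(b, h, hw, hw)
    attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
    attn_mask = (
        attn_mask.reshape(b, map_size, map_size)
        .unsqueeze(1)
        .repeat(1, latent_channel, 1, 1)
        .float()
    )
    attn_mask = F.interpolate(attn_mask, (latent_h, latent_w))
    print(attn_mask.shape)  # torch.Size([1, 4, 64, 64]) -- per-channel mask over the latents
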
@@ -652,13 +756,19 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
652 | 756 | ||
653 | beta_prod_t = 1 - alpha_prod_t | 757 | beta_prod_t = 1 - alpha_prod_t |
654 | if self.scheduler.config.prediction_type == "epsilon": | 758 | if self.scheduler.config.prediction_type == "epsilon": |
655 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) | 759 | pred_original_sample = ( |
760 | sample - beta_prod_t ** (0.5) * model_output | ||
761 | ) / alpha_prod_t ** (0.5) | ||
656 | elif self.scheduler.config.prediction_type == "sample": | 762 | elif self.scheduler.config.prediction_type == "sample": |
657 | pred_original_sample = model_output | 763 | pred_original_sample = model_output |
658 | elif self.scheduler.config.prediction_type == "v_prediction": | 764 | elif self.scheduler.config.prediction_type == "v_prediction": |
659 | pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output | 765 | pred_original_sample = (alpha_prod_t**0.5) * sample - ( |
766 | beta_prod_t**0.5 | ||
767 | ) * model_output | ||
660 | # predict V | 768 | # predict V |
661 | model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample | 769 | model_output = (alpha_prod_t**0.5) * model_output + ( |
770 | beta_prod_t**0.5 | ||
771 | ) * sample | ||
662 | else: | 772 | else: |
663 | raise ValueError( | 773 | raise ValueError( |
664 | f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," | 774 | f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," |
@@ -674,9 +784,13 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
674 | if self.scheduler.config.prediction_type == "epsilon": | 784 | if self.scheduler.config.prediction_type == "epsilon": |
675 | pred_eps = model_output | 785 | pred_eps = model_output |
676 | elif self.scheduler.config.prediction_type == "sample": | 786 | elif self.scheduler.config.prediction_type == "sample": |
677 | pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5) | 787 | pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / ( |
788 | beta_prod_t**0.5 | ||
789 | ) | ||
678 | elif self.scheduler.config.prediction_type == "v_prediction": | 790 | elif self.scheduler.config.prediction_type == "v_prediction": |
679 | pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output | 791 | pred_eps = (beta_prod_t**0.5) * sample + ( |
792 | alpha_prod_t**0.5 | ||
793 | ) * model_output | ||
680 | else: | 794 | else: |
681 | raise ValueError( | 795 | raise ValueError( |
682 | f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," | 796 | f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," |