Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  169
1 files changed, 162 insertions, 7 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index dab7878..66566b0 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -1,9 +1,11 @@
 import inspect
 import warnings
+import math
 from typing import List, Dict, Any, Optional, Union, Callable

 import numpy as np
 import torch
+import torchvision.transforms as T
 import PIL

 from diffusers.configuration_utils import FrozenDict
@@ -37,6 +39,35 @@ def preprocess(image):
     return 2.0 * image - 1.0


+class CrossAttnStoreProcessor:
+    def __init__(self):
+        self.attention_probs = None
+
+    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query)
+
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        self.attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(self.attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
 class VlpnStableDiffusion(DiffusionPipeline):
     def __init__(
         self,
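Note: CrossAttnStoreProcessor mirrors the default diffusers attention processor (q/k/v projections, softmax scores via get_attention_scores, output projection and dropout) and only adds the self.attention_probs cache, which the self-attention-guidance step further down reads back. A quick shape sketch of what gets cached, assuming 512x512 images (64x64 latents, an 8x8 feature map at the UNet mid block) and 8 attention heads; the concrete numbers are illustrative, not taken from the diff:

import math

batch, heads, map_h, map_w = 2, 8, 8, 8
hw = map_h * map_w                            # 64 spatial tokens at the mid block
attn_probs_shape = (batch * heads, hw, hw)    # shape of the cached attention_probs
assert math.isqrt(hw) == map_h                # how sag_masking recovers the 8x8 map below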
@@ -233,9 +264,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
         else:
             attention_mask = None

-        text_embeddings = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask)
+        prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask)

-        return text_embeddings
+        return prompt_embeds

     def get_timesteps(self, latents_are_image, num_inference_steps, strength, device):
         if latents_are_image:
@@ -330,6 +361,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         width: Optional[int] = None,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
+        sag_scale: float = 0.75,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
@@ -403,10 +435,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
         batch_size = len(prompt)
         device = self.execution_device
         do_classifier_free_guidance = guidance_scale > 1.0
+        do_self_attention_guidance = sag_scale > 0.0
         latents_are_image = isinstance(image, PIL.Image.Image)

         # 3. Encode input prompt
-        text_embeddings = self.encode_prompt(
+        prompt_embeds = self.encode_prompt(
             prompt,
             negative_prompt,
             num_images_per_prompt,
@@ -427,7 +460,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             image,
             latent_timestep,
             batch_size * num_images_per_prompt,
-            text_embeddings.dtype,
+            prompt_embeds.dtype,
             device,
             generator
         )
@@ -437,7 +470,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             num_channels_latents,
             height,
             width,
-            text_embeddings.dtype,
+            prompt_embeds.dtype,
             device,
             generator,
             image,
@@ -446,7 +479,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

         # 7. Denoising loop
+        if do_self_attention_guidance:
+            store_processor = CrossAttnStoreProcessor()
+            self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
+
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -458,7 +495,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 noise_pred = self.unet(
                     latent_model_input,
                     t,
-                    encoder_hidden_states=text_embeddings,
+                    encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
                 ).sample

@@ -467,6 +504,36 @@ class VlpnStableDiffusion(DiffusionPipeline):
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

+                if do_self_attention_guidance:
+                    # classifier-free guidance produces two chunks of attention map
+                    # and we only use unconditional one according to equation (24)
+                    # in https://arxiv.org/pdf/2210.00939.pdf
+                    if do_classifier_free_guidance:
+                        # DDIM-like prediction of x0
+                        pred_x0 = self.pred_x0_from_eps(latents, noise_pred_uncond, t)
+                        # get the stored attention maps
+                        uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
+                        # self-attention-based degrading of latents
+                        degraded_latents = self.sag_masking(
+                            pred_x0, uncond_attn, t, self.pred_eps_from_noise(latents, noise_pred_uncond, t)
+                        )
+                        uncond_emb, _ = prompt_embeds.chunk(2)
+                        # forward and give guidance
+                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
+                        noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
+                    else:
+                        # DDIM-like prediction of x0
+                        pred_x0 = self.pred_x0_from_eps(latents, noise_pred, t)
+                        # get the stored attention maps
+                        cond_attn = store_processor.attention_probs
+                        # self-attention-based degrading of latents
+                        degraded_latents = self.sag_masking(
+                            pred_x0, cond_attn, t, self.pred_eps_from_noise(latents, noise_pred, t)
+                        )
+                        # forward and give guidance
+                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
+                        noise_pred += sag_scale * (noise_pred - degraded_pred)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

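Note: writing eps_uncond, eps_text, and eps_degraded for noise_pred_uncond, noise_pred_text, and degraded_pred above, the combined per-step update with both guidance terms active is effectively

    eps_hat = eps_uncond + guidance_scale * (eps_text - eps_uncond) + sag_scale * (eps_uncond - eps_degraded)

where eps_degraded is the UNet prediction on the blurred-and-renoised latents produced by sag_masking (added below), and the unconditional attention map drives the masking, per the SAG paper referenced in the comments.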
@@ -490,3 +557,91 @@ class VlpnStableDiffusion(DiffusionPipeline):
             return (image, has_nsfw_concept)

         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+    # Self-Attention-Guided (SAG) Stable Diffusion
+
+    def sag_masking(self, original_latents, attn_map, t, eps):
+        # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
+        bh, hw1, hw2 = attn_map.shape
+        b, latent_channel, latent_h, latent_w = original_latents.shape
+        h = self.unet.attention_head_dim
+        if isinstance(h, list):
+            h = h[-1]
+        map_size = math.isqrt(hw1)
+
+        # Produce attention mask
+        attn_map = attn_map.reshape(b, h, hw1, hw2)
+        attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
+        attn_mask = (
+            attn_mask.reshape(b, map_size, map_size).unsqueeze(1).repeat(1, latent_channel, 1, 1).type(attn_map.dtype)
+        )
+        attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w))
+
+        # Blur according to the self-attention mask
+        transform = T.GaussianBlur(kernel_size=9, sigma=1.0)
+        degraded_latents = transform(original_latents)
+        degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask)
+
+        # Noise it again to match the noise level
+        degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t)
+
+        return degraded_latents
+
+    # Modified from diffusers.schedulers.scheduling_ddim.DDIMScheduler.step
+    def pred_x0_from_eps(self, sample, model_output, timestep):
+        # 1. get previous step value (=t-1)
+        # prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
+
+        # 2. compute alphas, betas
+        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+        # alpha_prod_t_prev = (
+        #     self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
+        # )
+
+        beta_prod_t = 1 - alpha_prod_t
+        # 3. compute predicted original sample from predicted noise also called
+        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        if self.scheduler.config.prediction_type == "epsilon":
+            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+        elif self.scheduler.config.prediction_type == "sample":
+            pred_original_sample = model_output
+        elif self.scheduler.config.prediction_type == "v_prediction":
+            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+            # predict V
+            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
+                " or `v_prediction`"
+            )
+        # # 4. Clip "predicted x_0"
+        # if self.scheduler.config.clip_sample:
+        #     pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+        return pred_original_sample
+
+    def pred_eps_from_noise(self, sample, model_output, timestep):
+        # 1. get previous step value (=t-1)
+        # prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
+
+        # 2. compute alphas, betas
+        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+        # alpha_prod_t_prev = (
+        #     self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
+        # )
+
+        beta_prod_t = 1 - alpha_prod_t
+        # 3. compute predicted eps from model output
+        if self.scheduler.config.prediction_type == "epsilon":
+            pred_eps = model_output
+        elif self.scheduler.config.prediction_type == "sample":
+            pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5)
+        elif self.scheduler.config.prediction_type == "v_prediction":
+            pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
+                " or `v_prediction`"
+            )
+
+        return pred_eps
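Note: a minimal usage sketch of the extended __call__ signature. It assumes a VlpnStableDiffusion instance has already been assembled as pipe (component loading is elided) and that the script runs from the repository root; prompt and seed are placeholders:

import torch

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

# pipe = VlpnStableDiffusion(...)  # construction elided; see __init__ above for the components
output = pipe(
    prompt=["a photo of an astronaut riding a horse"],
    num_inference_steps=50,
    guidance_scale=7.5,  # classifier-free guidance is active when > 1.0
    sag_scale=0.75,      # new in this change; 0.0 disables self-attention guidance
    generator=torch.Generator("cuda").manual_seed(0),
)
image = output.images[0]

With sag_scale=0.0 the new branches are skipped entirely, so the pipeline avoids the extra UNet forward pass per step and behaves as before.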