author    | Volpeon <git@volpeon.ink> | 2023-02-08 11:38:56 +0100
committer | Volpeon <git@volpeon.ink> | 2023-02-08 11:38:56 +0100
commit    | 347ad308f8223d966793f0421c72432f7e912377 (patch)
tree      | 2b7319dc37787ce2828101c451987d086dd47360 /pipelines
parent    | Fixed Lora training (diff)
Integrate Self-Attention-Guided (SAG) Stable Diffusion in my custom pipeline
Diffstat (limited to 'pipelines')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 169
1 file changed, 162 insertions, 7 deletions
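For context, a rough usage sketch of the `sag_scale` argument this commit adds. The checkpoint name, the constructor arguments of VlpnStableDiffusion, the prompt, and the scheduler choice below are illustrative assumptions and not part of the commit; only `sag_scale` itself (default 0.75, with 0.0 disabling self-attention guidance) comes from the diff.

    from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
    from transformers import CLIPTextModel, CLIPTokenizer

    from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

    # Placeholder checkpoint; any Stable Diffusion checkpoint with the usual
    # vae/text_encoder/tokenizer/unet/scheduler subfolders is assumed here.
    model_id = "runwayml/stable-diffusion-v1-5"

    # Constructor arguments are assumed to mirror the standard SD pipelines.
    pipeline = VlpnStableDiffusion(
        vae=AutoencoderKL.from_pretrained(model_id, subfolder="vae"),
        text_encoder=CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder"),
        tokenizer=CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer"),
        unet=UNet2DConditionModel.from_pretrained(model_id, subfolder="unet"),
        scheduler=DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler"),
    ).to("cuda")

    # sag_scale > 0.0 enables the new self-attention guidance path.
    image = pipeline(
        prompt="a photo of a red fox in the snow",
        num_inference_steps=50,
        guidance_scale=7.5,
        sag_scale=0.75,
    ).images[0]
    image.save("fox.png")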
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index dab7878..66566b0 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -1,9 +1,11 @@
 import inspect
 import warnings
+import math
 from typing import List, Dict, Any, Optional, Union, Callable
 
 import numpy as np
 import torch
+import torchvision.transforms as T
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
@@ -37,6 +39,35 @@ def preprocess(image):
     return 2.0 * image - 1.0
 
 
+class CrossAttnStoreProcessor:
+    def __init__(self):
+        self.attention_probs = None
+
+    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query)
+
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        self.attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(self.attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
 class VlpnStableDiffusion(DiffusionPipeline):
     def __init__(
         self,
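Aside (not part of the commit): the processor above recomputes the default attention forward pass while caching `attention_probs` for later use. As a rough shape sketch, assuming a 512x512 generation with a Stable Diffusion 1.x UNet (64x64 latents, an 8x8 mid-block feature map, 8 heads), the stored self-attention map is exactly what `sag_masking()` later unpacks:

    import math
    import torch

    # Assumed numbers: batch of 2 (uncond + cond), 8 heads, 8x8 mid-block tokens.
    batch, heads, tokens = 2, 8, 8 * 8
    attention_probs = torch.softmax(torch.randn(batch * heads, tokens, tokens), dim=-1)

    bh, hw1, hw2 = attention_probs.shape   # (16, 64, 64), as unpacked in sag_masking()
    map_size = math.isqrt(hw1)             # 8
    print(bh, hw1, hw2, map_size)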
@@ -233,9 +264,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
         else:
             attention_mask = None
 
-        text_embeddings = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask)
+        prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask)
 
-        return text_embeddings
+        return prompt_embeds
 
     def get_timesteps(self, latents_are_image, num_inference_steps, strength, device):
         if latents_are_image:
@@ -330,6 +361,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         width: Optional[int] = None,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
+        sag_scale: float = 0.75,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
@@ -403,10 +435,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
         batch_size = len(prompt)
         device = self.execution_device
         do_classifier_free_guidance = guidance_scale > 1.0
+        do_self_attention_guidance = sag_scale > 0.0
         latents_are_image = isinstance(image, PIL.Image.Image)
 
         # 3. Encode input prompt
-        text_embeddings = self.encode_prompt(
+        prompt_embeds = self.encode_prompt(
             prompt,
             negative_prompt,
             num_images_per_prompt,
@@ -427,7 +460,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             image,
             latent_timestep,
             batch_size * num_images_per_prompt,
-            text_embeddings.dtype,
+            prompt_embeds.dtype,
             device,
             generator
         )
@@ -437,7 +470,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             num_channels_latents,
             height,
             width,
-            text_embeddings.dtype,
+            prompt_embeds.dtype,
             device,
             generator,
             image,
@@ -446,7 +479,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 7. Denoising loop
+        # 7. Denoising loo
+        if do_self_attention_guidance:
+            store_processor = CrossAttnStoreProcessor()
+            self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
+
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -458,7 +495,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 noise_pred = self.unet(
                     latent_model_input,
                     t,
-                    encoder_hidden_states=text_embeddings,
+                    encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
                 ).sample
 
@@ -467,6 +504,36 @@ class VlpnStableDiffusion(DiffusionPipeline):
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
+                if do_self_attention_guidance:
+                    # classifier-free guidance produces two chunks of attention map
+                    # and we only use unconditional one according to equation (24)
+                    # in https://arxiv.org/pdf/2210.00939.pdf
+                    if do_classifier_free_guidance:
+                        # DDIM-like prediction of x0
+                        pred_x0 = self.pred_x0_from_eps(latents, noise_pred_uncond, t)
+                        # get the stored attention maps
+                        uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
+                        # self-attention-based degrading of latents
+                        degraded_latents = self.sag_masking(
+                            pred_x0, uncond_attn, t, self.pred_eps_from_noise(latents, noise_pred_uncond, t)
+                        )
+                        uncond_emb, _ = prompt_embeds.chunk(2)
+                        # forward and give guidance
+                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
+                        noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
+                    else:
+                        # DDIM-like prediction of x0
+                        pred_x0 = self.pred_x0_from_eps(latents, noise_pred, t)
+                        # get the stored attention maps
+                        cond_attn = store_processor.attention_probs
+                        # self-attention-based degrading of latents
+                        degraded_latents = self.sag_masking(
+                            pred_x0, cond_attn, t, self.pred_eps_from_noise(latents, noise_pred, t)
+                        )
+                        # forward and give guidance
+                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
+                        noise_pred += sag_scale * (noise_pred - degraded_pred)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
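Read directly off the code in this hunk, the update applied to the noise prediction in the classifier-free-guidance branch is, with g = guidance_scale, s = sag_scale, and \hat{x}_t the blurred-and-renoised latents returned by sag_masking():

    \tilde{\epsilon} = \epsilon_\theta(x_t, \varnothing)
        + g \, \bigl( \epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing) \bigr)
        + s \, \bigl( \epsilon_\theta(x_t, \varnothing) - \epsilon_\theta(\hat{x}_t, \varnothing) \bigr)

In the branch without classifier-free guidance the same form holds with the conditional prediction throughout: \tilde{\epsilon} = \epsilon_\theta(x_t, c) + s \, ( \epsilon_\theta(x_t, c) - \epsilon_\theta(\hat{x}_t, c) ).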
@@ -490,3 +557,91 @@ class VlpnStableDiffusion(DiffusionPipeline):
             return (image, has_nsfw_concept)
 
         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+    # Self-Attention-Guided (SAG) Stable Diffusion
+
+    def sag_masking(self, original_latents, attn_map, t, eps):
+        # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
+        bh, hw1, hw2 = attn_map.shape
+        b, latent_channel, latent_h, latent_w = original_latents.shape
+        h = self.unet.attention_head_dim
+        if isinstance(h, list):
+            h = h[-1]
+        map_size = math.isqrt(hw1)
+
+        # Produce attention mask
+        attn_map = attn_map.reshape(b, h, hw1, hw2)
+        attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
+        attn_mask = (
+            attn_mask.reshape(b, map_size, map_size).unsqueeze(1).repeat(1, latent_channel, 1, 1).type(attn_map.dtype)
+        )
+        attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w))
+
+        # Blur according to the self-attention mask
+        transform = T.GaussianBlur(kernel_size=9, sigma=1.0)
+        degraded_latents = transform(original_latents)
+        degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask)
+
+        # Noise it again to match the noise level
+        degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t)
+
+        return degraded_latents
+
+    # Modified from diffusers.schedulers.scheduling_ddim.DDIMScheduler.step
+    def pred_x0_from_eps(self, sample, model_output, timestep):
+        # 1. get previous step value (=t-1)
+        # prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
+
+        # 2. compute alphas, betas
+        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+        # alpha_prod_t_prev = (
+        #     self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
+        # )
+
+        beta_prod_t = 1 - alpha_prod_t
+        # 3. compute predicted original sample from predicted noise also called
+        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        if self.scheduler.config.prediction_type == "epsilon":
+            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+        elif self.scheduler.config.prediction_type == "sample":
+            pred_original_sample = model_output
+        elif self.scheduler.config.prediction_type == "v_prediction":
+            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+            # predict V
+            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
+                " or `v_prediction`"
+            )
+        # # 4. Clip "predicted x_0"
+        # if self.scheduler.config.clip_sample:
+        #     pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+        return pred_original_sample
+
+    def pred_eps_from_noise(self, sample, model_output, timestep):
+        # 1. get previous step value (=t-1)
+        # prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
+
+        # 2. compute alphas, betas
+        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+        # alpha_prod_t_prev = (
+        #     self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
+        # )
+
+        beta_prod_t = 1 - alpha_prod_t
+        # 3. compute predicted eps from model output
+        if self.scheduler.config.prediction_type == "epsilon":
+            pred_eps = model_output
+        elif self.scheduler.config.prediction_type == "sample":
+            pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5)
+        elif self.scheduler.config.prediction_type == "v_prediction":
+            pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
+                " or `v_prediction`"
+            )
+
+        return pred_eps
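For reference, the two helpers above invert the forward relation x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\, \epsilon used by the scheduler. For prediction_type "epsilon" they reduce to

    x_0 = \frac{x_t - \sqrt{1 - \bar\alpha_t}\, \epsilon_\theta}{\sqrt{\bar\alpha_t}},
    \qquad \epsilon = \epsilon_\theta

and for "v_prediction" to x_0 = \sqrt{\bar\alpha_t}\, x_t - \sqrt{1-\bar\alpha_t}\, v_\theta and \epsilon = \sqrt{1-\bar\alpha_t}\, x_t + \sqrt{\bar\alpha_t}\, v_\theta, matching the epsilon and v_prediction branches in the code above.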