summaryrefslogtreecommitdiffstats
path: root/pipelines/stable_diffusion
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-02-13 17:19:18 +0100
committerVolpeon <git@volpeon.ink>2023-02-13 17:19:18 +0100
commit94b676d91382267e7429bd68362019868affd9d1 (patch)
tree513697739ab25217cbfcff630299d02b1f6e98c8 /pipelines/stable_diffusion
parentIntegrate Self-Attention-Guided (SAG) Stable Diffusion in my custom pipeline (diff)
downloadtextual-inversion-diff-94b676d91382267e7429bd68362019868affd9d1.tar.gz
textual-inversion-diff-94b676d91382267e7429bd68362019868affd9d1.tar.bz2
textual-inversion-diff-94b676d91382267e7429bd68362019868affd9d1.zip
Update
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py69
1 files changed, 37 insertions, 32 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 66566b0..cb09fe1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -5,7 +5,7 @@ from typing import List, Dict, Any, Optional, Union, Callable
5 5
6import numpy as np 6import numpy as np
7import torch 7import torch
8import torchvision.transforms as T 8import torch.nn.functional as F
9import PIL 9import PIL
10 10
11from diffusers.configuration_utils import FrozenDict 11from diffusers.configuration_utils import FrozenDict
@@ -39,6 +39,27 @@ def preprocess(image):
39 return 2.0 * image - 1.0 39 return 2.0 * image - 1.0
40 40
41 41
42def gaussian_blur_2d(img, kernel_size, sigma):
43 ksize_half = (kernel_size - 1) * 0.5
44
45 x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
46
47 pdf = torch.exp(-0.5 * (x / sigma).pow(2))
48
49 x_kernel = pdf / pdf.sum()
50 x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
51
52 kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
53 kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
54
55 padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
56
57 img = F.pad(img, padding, mode="reflect")
58 img = F.conv2d(img, kernel2d, groups=img.shape[-3])
59
60 return img
61
62
42class CrossAttnStoreProcessor: 63class CrossAttnStoreProcessor:
43 def __init__(self): 64 def __init__(self):
44 self.attention_probs = None 65 self.attention_probs = None
@@ -46,13 +67,17 @@ class CrossAttnStoreProcessor:
46 def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): 67 def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
47 batch_size, sequence_length, _ = hidden_states.shape 68 batch_size, sequence_length, _ = hidden_states.shape
48 attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 69 attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
49
50 query = attn.to_q(hidden_states) 70 query = attn.to_q(hidden_states)
51 query = attn.head_to_batch_dim(query)
52 71
53 encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 72 if encoder_hidden_states is None:
73 encoder_hidden_states = hidden_states
74 elif attn.cross_attention_norm:
75 encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
76
54 key = attn.to_k(encoder_hidden_states) 77 key = attn.to_k(encoder_hidden_states)
55 value = attn.to_v(encoder_hidden_states) 78 value = attn.to_v(encoder_hidden_states)
79
80 query = attn.head_to_batch_dim(query)
56 key = attn.head_to_batch_dim(key) 81 key = attn.head_to_batch_dim(key)
57 value = attn.head_to_batch_dim(value) 82 value = attn.head_to_batch_dim(value)
58 83
@@ -510,12 +535,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
510 # in https://arxiv.org/pdf/2210.00939.pdf 535 # in https://arxiv.org/pdf/2210.00939.pdf
511 if do_classifier_free_guidance: 536 if do_classifier_free_guidance:
512 # DDIM-like prediction of x0 537 # DDIM-like prediction of x0
513 pred_x0 = self.pred_x0_from_eps(latents, noise_pred_uncond, t) 538 pred_x0 = self.pred_x0(latents, noise_pred_uncond, t)
514 # get the stored attention maps 539 # get the stored attention maps
515 uncond_attn, cond_attn = store_processor.attention_probs.chunk(2) 540 uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
516 # self-attention-based degrading of latents 541 # self-attention-based degrading of latents
517 degraded_latents = self.sag_masking( 542 degraded_latents = self.sag_masking(
518 pred_x0, uncond_attn, t, self.pred_eps_from_noise(latents, noise_pred_uncond, t) 543 pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t)
519 ) 544 )
520 uncond_emb, _ = prompt_embeds.chunk(2) 545 uncond_emb, _ = prompt_embeds.chunk(2)
521 # forward and give guidance 546 # forward and give guidance
@@ -523,12 +548,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
523 noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) 548 noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
524 else: 549 else:
525 # DDIM-like prediction of x0 550 # DDIM-like prediction of x0
526 pred_x0 = self.pred_x0_from_eps(latents, noise_pred, t) 551 pred_x0 = self.pred_x0(latents, noise_pred, t)
527 # get the stored attention maps 552 # get the stored attention maps
528 cond_attn = store_processor.attention_probs 553 cond_attn = store_processor.attention_probs
529 # self-attention-based degrading of latents 554 # self-attention-based degrading of latents
530 degraded_latents = self.sag_masking( 555 degraded_latents = self.sag_masking(
531 pred_x0, cond_attn, t, self.pred_eps_from_noise(latents, noise_pred, t) 556 pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t)
532 ) 557 )
533 # forward and give guidance 558 # forward and give guidance
534 degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample 559 degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
@@ -578,8 +603,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
578 attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w)) 603 attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w))
579 604
580 # Blur according to the self-attention mask 605 # Blur according to the self-attention mask
581 transform = T.GaussianBlur(kernel_size=9, sigma=1.0) 606 degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0)
582 degraded_latents = transform(original_latents)
583 degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask) 607 degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask)
584 608
585 # Noise it again to match the noise level 609 # Noise it again to match the noise level
@@ -588,19 +612,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
588 return degraded_latents 612 return degraded_latents
589 613
590 # Modified from diffusers.schedulers.scheduling_ddim.DDIMScheduler.step 614 # Modified from diffusers.schedulers.scheduling_ddim.DDIMScheduler.step
591 def pred_x0_from_eps(self, sample, model_output, timestep): 615 # Note: there are some schedulers that clip or do not return x_0 (PNDMScheduler, DDIMScheduler, etc.)
592 # 1. get previous step value (=t-1) 616 def pred_x0(self, sample, model_output, timestep):
593 # prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
594
595 # 2. compute alphas, betas
596 alpha_prod_t = self.scheduler.alphas_cumprod[timestep] 617 alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
597 # alpha_prod_t_prev = (
598 # self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
599 # )
600 618
601 beta_prod_t = 1 - alpha_prod_t 619 beta_prod_t = 1 - alpha_prod_t
602 # 3. compute predicted original sample from predicted noise also called
603 # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
604 if self.scheduler.config.prediction_type == "epsilon": 620 if self.scheduler.config.prediction_type == "epsilon":
605 pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) 621 pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
606 elif self.scheduler.config.prediction_type == "sample": 622 elif self.scheduler.config.prediction_type == "sample":
@@ -614,24 +630,13 @@ class VlpnStableDiffusion(DiffusionPipeline):
614 f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," 630 f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
615 " or `v_prediction`" 631 " or `v_prediction`"
616 ) 632 )
617 # # 4. Clip "predicted x_0"
618 # if self.scheduler.config.clip_sample:
619 # pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
620 633
621 return pred_original_sample 634 return pred_original_sample
622 635
623 def pred_eps_from_noise(self, sample, model_output, timestep): 636 def pred_epsilon(self, sample, model_output, timestep):
624 # 1. get previous step value (=t-1)
625 # prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
626
627 # 2. compute alphas, betas
628 alpha_prod_t = self.scheduler.alphas_cumprod[timestep] 637 alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
629 # alpha_prod_t_prev = (
630 # self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
631 # )
632 638
633 beta_prod_t = 1 - alpha_prod_t 639 beta_prod_t = 1 - alpha_prod_t
634 # 3. compute predicted eps from model output
635 if self.scheduler.config.prediction_type == "epsilon": 640 if self.scheduler.config.prediction_type == "epsilon":
636 pred_eps = model_output 641 pred_eps = model_output
637 elif self.scheduler.config.prediction_type == "sample": 642 elif self.scheduler.config.prediction_type == "sample":