Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py | 69
1 file changed, 37 insertions(+), 32 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 66566b0..cb09fe1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -5,7 +5,7 @@ from typing import List, Dict, Any, Optional, Union, Callable
 
 import numpy as np
 import torch
-import torchvision.transforms as T
+import torch.nn.functional as F
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
@@ -39,6 +39,27 @@ def preprocess(image):
     return 2.0 * image - 1.0
 
 
+def gaussian_blur_2d(img, kernel_size, sigma):
+    ksize_half = (kernel_size - 1) * 0.5
+
+    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
+
+    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
+
+    x_kernel = pdf / pdf.sum()
+    x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
+
+    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
+    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
+
+    padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
+
+    img = F.pad(img, padding, mode="reflect")
+    img = F.conv2d(img, kernel2d, groups=img.shape[-3])
+
+    return img
+
+
 class CrossAttnStoreProcessor:
     def __init__(self):
         self.attention_probs = None
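The added gaussian_blur_2d builds a 1-D Gaussian, takes the outer product to get a 2-D kernel, and applies it as a depthwise (grouped) convolution with reflect padding, so the blur no longer needs torchvision. A minimal usage sketch, assuming the helper above is in scope and an illustrative latent shape:

    # Sketch: stands in for the removed T.GaussianBlur(kernel_size=9, sigma=1.0).
    import torch

    latents = torch.randn(2, 4, 64, 64)  # (batch, channels, height, width); shape is illustrative
    blurred = gaussian_blur_2d(latents, kernel_size=9, sigma=1.0)
    assert blurred.shape == latents.shape  # reflect padding preserves the spatial size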
@@ -46,13 +67,17 @@ class CrossAttnStoreProcessor:
     def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
         batch_size, sequence_length, _ = hidden_states.shape
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
         query = attn.to_q(hidden_states)
-        query = attn.head_to_batch_dim(query)
 
-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.cross_attention_norm:
+            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)
 
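The processor now mirrors the newer diffusers attention-processor flow: the query projection happens before the head split, and norm_cross is applied when cross_attention_norm is set. How the processor gets attached is outside this hunk; a hedged sketch of the usual pattern (the attribute path is an assumption, following the upstream SAG pipeline):

    # Sketch (assumed wiring, not shown in this diff): capture the mid-block
    # self-attention probabilities so sag_masking can derive its mask from them.
    store_processor = CrossAttnStoreProcessor()
    pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
    # after the UNet forward pass, store_processor.attention_probs holds the maps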
@@ -510,12 +535,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 # in https://arxiv.org/pdf/2210.00939.pdf
                 if do_classifier_free_guidance:
                     # DDIM-like prediction of x0
-                    pred_x0 = self.pred_x0_from_eps(latents, noise_pred_uncond, t)
+                    pred_x0 = self.pred_x0(latents, noise_pred_uncond, t)
                     # get the stored attention maps
                     uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
                     # self-attention-based degrading of latents
                     degraded_latents = self.sag_masking(
-                        pred_x0, uncond_attn, t, self.pred_eps_from_noise(latents, noise_pred_uncond, t)
+                        pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t)
                     )
                     uncond_emb, _ = prompt_embeds.chunk(2)
                     # forward and give guidance
@@ -523,12 +548,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
                     noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
                 else:
                     # DDIM-like prediction of x0
-                    pred_x0 = self.pred_x0_from_eps(latents, noise_pred, t)
+                    pred_x0 = self.pred_x0(latents, noise_pred, t)
                     # get the stored attention maps
                     cond_attn = store_processor.attention_probs
                     # self-attention-based degrading of latents
                     degraded_latents = self.sag_masking(
-                        pred_x0, cond_attn, t, self.pred_eps_from_noise(latents, noise_pred, t)
+                        pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t)
                     )
                     # forward and give guidance
                     degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
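Both branches implement self-attention guidance (arXiv:2210.00939): predict x0, blur it where the self-attention map is strongest, re-noise it to the current timestep, run the UNet on that degraded latent, and push the prediction away from the degraded estimate. A condensed sketch of the update order in the classifier-free-guidance branch (noise_pred_text and guidance_scale are assumed names from the surrounding, unchanged code):

    # Sketch: order of the guidance updates in the CFG branch (assumed surrounding names).
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    # ... sag_masking() plus a second UNet pass on degraded_latents yields degraded_pred ...
    noise_pred = noise_pred + sag_scale * (noise_pred_uncond - degraded_pred)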
@@ -578,8 +603,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w))
 
         # Blur according to the self-attention mask
-        transform = T.GaussianBlur(kernel_size=9, sigma=1.0)
-        degraded_latents = transform(original_latents)
+        degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0)
         degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask)
 
         # Noise it again to match the noise level
@@ -588,19 +612,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
         return degraded_latents
 
     # Modified from diffusers.schedulers.scheduling_ddim.DDIMScheduler.step
-    def pred_x0_from_eps(self, sample, model_output, timestep):
-        # 1. get previous step value (=t-1)
-        # prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
-
-        # 2. compute alphas, betas
+    # Note: there are some schedulers that clip or do not return x_0 (PNDMScheduler, DDIMScheduler, etc.)
+    def pred_x0(self, sample, model_output, timestep):
         alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
-        # alpha_prod_t_prev = (
-        #     self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
-        # )
 
         beta_prod_t = 1 - alpha_prod_t
-        # 3. compute predicted original sample from predicted noise also called
-        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
         if self.scheduler.config.prediction_type == "epsilon":
             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
         elif self.scheduler.config.prediction_type == "sample":
@@ -614,24 +630,13 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
                 " or `v_prediction`"
             )
-        # # 4. Clip "predicted x_0"
-        # if self.scheduler.config.clip_sample:
-        #     pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
 
         return pred_original_sample
 
-    def pred_eps_from_noise(self, sample, model_output, timestep):
-        # 1. get previous step value (=t-1)
-        # prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
-
-        # 2. compute alphas, betas
+    def pred_epsilon(self, sample, model_output, timestep):
         alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
-        # alpha_prod_t_prev = (
-        #     self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
-        # )
 
         beta_prod_t = 1 - alpha_prod_t
-        # 3. compute predicted eps from model output
         if self.scheduler.config.prediction_type == "epsilon":
             pred_eps = model_output
         elif self.scheduler.config.prediction_type == "sample":
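For reference, pred_x0 and pred_epsilon invert the forward-noising relation x_t = sqrt(alpha_t) * x_0 + sqrt(1 - alpha_t) * eps at the current timestep (formula (12) of the DDIM paper), branching on the scheduler's prediction_type. A small self-contained sanity check of the epsilon branch:

    # Sketch: verify the epsilon branch of pred_x0 against the DDIM forward relation.
    import torch

    alpha_prod_t = torch.tensor(0.8)  # stand-in for scheduler.alphas_cumprod[t]
    beta_prod_t = 1 - alpha_prod_t
    x0 = torch.randn(1, 4, 64, 64)
    eps = torch.randn_like(x0)

    sample = alpha_prod_t**0.5 * x0 + beta_prod_t**0.5 * eps          # noisy latent x_t
    pred_x0 = (sample - beta_prod_t**0.5 * eps) / alpha_prod_t**0.5   # epsilon-prediction branch
    assert torch.allclose(pred_x0, x0, atol=1e-5)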