From 94b676d91382267e7429bd68362019868affd9d1 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 13 Feb 2023 17:19:18 +0100 Subject: Update --- .../stable_diffusion/vlpn_stable_diffusion.py | 69 ++++++++++++---------- 1 file changed, 37 insertions(+), 32 deletions(-) (limited to 'pipelines/stable_diffusion') diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 66566b0..cb09fe1 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -5,7 +5,7 @@ from typing import List, Dict, Any, Optional, Union, Callable import numpy as np import torch -import torchvision.transforms as T +import torch.nn.functional as F import PIL from diffusers.configuration_utils import FrozenDict @@ -39,6 +39,27 @@ def preprocess(image): return 2.0 * image - 1.0 +def gaussian_blur_2d(img, kernel_size, sigma): + ksize_half = (kernel_size - 1) * 0.5 + + x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + + x_kernel = pdf / pdf.sum() + x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) + + kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) + kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + + img = F.pad(img, padding, mode="reflect") + img = F.conv2d(img, kernel2d, groups=img.shape[-3]) + + return img + + class CrossAttnStoreProcessor: def __init__(self): self.attention_probs = None @@ -46,13 +67,17 @@ class CrossAttnStoreProcessor: def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - query = attn.to_q(hidden_states) - query = attn.head_to_batch_dim(query) - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.cross_attention_norm: + encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) @@ -510,12 +535,12 @@ class VlpnStableDiffusion(DiffusionPipeline): # in https://arxiv.org/pdf/2210.00939.pdf if do_classifier_free_guidance: # DDIM-like prediction of x0 - pred_x0 = self.pred_x0_from_eps(latents, noise_pred_uncond, t) + pred_x0 = self.pred_x0(latents, noise_pred_uncond, t) # get the stored attention maps uncond_attn, cond_attn = store_processor.attention_probs.chunk(2) # self-attention-based degrading of latents degraded_latents = self.sag_masking( - pred_x0, uncond_attn, t, self.pred_eps_from_noise(latents, noise_pred_uncond, t) + pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t) ) uncond_emb, _ = prompt_embeds.chunk(2) # forward and give guidance @@ -523,12 +548,12 @@ class VlpnStableDiffusion(DiffusionPipeline): noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) else: # DDIM-like prediction of x0 - pred_x0 = self.pred_x0_from_eps(latents, noise_pred, t) + pred_x0 = self.pred_x0(latents, noise_pred, t) # get the stored attention maps cond_attn = store_processor.attention_probs # self-attention-based degrading of latents degraded_latents = self.sag_masking( - pred_x0, cond_attn, t, self.pred_eps_from_noise(latents, noise_pred, t) + pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t) ) # forward and give guidance degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample @@ -578,8 +603,7 @@ class VlpnStableDiffusion(DiffusionPipeline): attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w)) # Blur according to the self-attention mask - transform = T.GaussianBlur(kernel_size=9, sigma=1.0) - degraded_latents = transform(original_latents) + degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0) degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask) # Noise it again to match the noise level @@ -588,19 +612,11 @@ class VlpnStableDiffusion(DiffusionPipeline): return degraded_latents # Modified from diffusers.schedulers.scheduling_ddim.DDIMScheduler.step - def pred_x0_from_eps(self, sample, model_output, timestep): - # 1. get previous step value (=t-1) - # prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps - - # 2. compute alphas, betas + # Note: there are some schedulers that clip or do not return x_0 (PNDMScheduler, DDIMScheduler, etc.) + def pred_x0(self, sample, model_output, timestep): alpha_prod_t = self.scheduler.alphas_cumprod[timestep] - # alpha_prod_t_prev = ( - # self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod - # ) beta_prod_t = 1 - alpha_prod_t - # 3. compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.scheduler.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) elif self.scheduler.config.prediction_type == "sample": @@ -614,24 +630,13 @@ class VlpnStableDiffusion(DiffusionPipeline): f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," " or `v_prediction`" ) - # # 4. Clip "predicted x_0" - # if self.scheduler.config.clip_sample: - # pred_original_sample = torch.clamp(pred_original_sample, -1, 1) return pred_original_sample - def pred_eps_from_noise(self, sample, model_output, timestep): - # 1. get previous step value (=t-1) - # prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps - - # 2. compute alphas, betas + def pred_epsilon(self, sample, model_output, timestep): alpha_prod_t = self.scheduler.alphas_cumprod[timestep] - # alpha_prod_t_prev = ( - # self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod - # ) beta_prod_t = 1 - alpha_prod_t - # 3. compute predicted eps from model output if self.scheduler.config.prediction_type == "epsilon": pred_eps = model_output elif self.scheduler.config.prediction_type == "sample": -- cgit v1.2.3-70-g09d2