From 347ad308f8223d966793f0421c72432f7e912377 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 8 Feb 2023 11:38:56 +0100
Subject: Integrate Self-Attention-Guided (SAG) Stable Diffusion in my custom
 pipeline

---
 .../stable_diffusion/vlpn_stable_diffusion.py | 169 ++++++++++++++++++++-
 train_dreambooth.py                           |   2 +-
 train_lora.py                                 |   8 -
 train_ti.py                                   |   2 +-
 4 files changed, 164 insertions(+), 17 deletions(-)

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index dab7878..66566b0 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -1,9 +1,11 @@
 import inspect
 import warnings
+import math
 from typing import List, Dict, Any, Optional, Union, Callable
 
 import numpy as np
 import torch
+import torchvision.transforms as T
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
@@ -37,6 +39,35 @@ def preprocess(image):
     return 2.0 * image - 1.0
 
 
+class CrossAttnStoreProcessor:
+    def __init__(self):
+        self.attention_probs = None
+
+    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query)
+
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        self.attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(self.attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
 class VlpnStableDiffusion(DiffusionPipeline):
     def __init__(
         self,
@@ -233,9 +264,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
         else:
             attention_mask = None
 
-        text_embeddings = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask)
+        prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask)
 
-        return text_embeddings
+        return prompt_embeds
 
     def get_timesteps(self, latents_are_image, num_inference_steps, strength, device):
         if latents_are_image:
@@ -330,6 +361,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         width: Optional[int] = None,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
+        sag_scale: float = 0.75,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
@@ -403,10 +435,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
             batch_size = len(prompt)
 
         device = self.execution_device
         do_classifier_free_guidance = guidance_scale > 1.0
+        do_self_attention_guidance = sag_scale > 0.0
         latents_are_image = isinstance(image, PIL.Image.Image)
 
         # 3. Encode input prompt
-        text_embeddings = self.encode_prompt(
+        prompt_embeds = self.encode_prompt(
             prompt,
             negative_prompt,
             num_images_per_prompt,
@@ -427,7 +460,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 image,
                 latent_timestep,
                 batch_size * num_images_per_prompt,
-                text_embeddings.dtype,
+                prompt_embeds.dtype,
                 device,
                 generator
             )
@@ -437,7 +470,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 num_channels_latents,
                 height,
                 width,
-                text_embeddings.dtype,
+                prompt_embeds.dtype,
                 device,
                 generator,
                 image,
@@ -446,7 +479,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 7. Denoising loop
+        if do_self_attention_guidance:
+            store_processor = CrossAttnStoreProcessor()
+            self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
+
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -458,7 +495,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 noise_pred = self.unet(
                     latent_model_input,
                     t,
-                    encoder_hidden_states=text_embeddings,
+                    encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
                 ).sample
 
@@ -467,6 +504,36 @@ class VlpnStableDiffusion(DiffusionPipeline):
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
+                if do_self_attention_guidance:
+                    # classifier-free guidance produces two chunks of attention map
+                    # and we only use unconditional one according to equation (24)
+                    # in https://arxiv.org/pdf/2210.00939.pdf
+                    if do_classifier_free_guidance:
+                        # DDIM-like prediction of x0
+                        pred_x0 = self.pred_x0_from_eps(latents, noise_pred_uncond, t)
+                        # get the stored attention maps
+                        uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
+                        # self-attention-based degrading of latents
+                        degraded_latents = self.sag_masking(
+                            pred_x0, uncond_attn, t, self.pred_eps_from_noise(latents, noise_pred_uncond, t)
+                        )
+                        uncond_emb, _ = prompt_embeds.chunk(2)
+                        # forward and give guidance
+                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
+                        noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
+                    else:
+                        # DDIM-like prediction of x0
+                        pred_x0 = self.pred_x0_from_eps(latents, noise_pred, t)
+                        # get the stored attention maps
+                        cond_attn = store_processor.attention_probs
+                        # self-attention-based degrading of latents
+                        degraded_latents = self.sag_masking(
+                            pred_x0, cond_attn, t, self.pred_eps_from_noise(latents, noise_pred, t)
+                        )
+                        # forward and give guidance
+                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
+                        noise_pred += sag_scale * (noise_pred - degraded_pred)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
@@ -490,3 +557,91 @@ class VlpnStableDiffusion(DiffusionPipeline):
             return (image, has_nsfw_concept)
 
         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+    # Self-Attention-Guided (SAG) Stable Diffusion
+
+    def sag_masking(self, original_latents, attn_map, t, eps):
+        # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
+        bh, hw1, hw2 = attn_map.shape
+        b, latent_channel, latent_h, latent_w = original_latents.shape
+        h = self.unet.attention_head_dim
+        if isinstance(h, list):
+            h = h[-1]
+        map_size = math.isqrt(hw1)
+
+        # Produce attention mask
+        attn_map = attn_map.reshape(b, h, hw1, hw2)
+        attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
+        attn_mask = (
+            attn_mask.reshape(b, map_size, map_size).unsqueeze(1).repeat(1, latent_channel, 1, 1).type(attn_map.dtype)
+        )
+        attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w))
+
+        # Blur according to the self-attention mask
+        transform = T.GaussianBlur(kernel_size=9, sigma=1.0)
+        degraded_latents = transform(original_latents)
+        degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask)
+
+        # Noise it again to match the noise level
+        degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t)
+
+        return degraded_latents
+
+    # Modified from diffusers.schedulers.scheduling_ddim.DDIMScheduler.step
+    def pred_x0_from_eps(self, sample, model_output, timestep):
+        # 1. get previous step value (=t-1)
+        # prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
+
+        # 2. compute alphas, betas
+        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+        # alpha_prod_t_prev = (
+        #     self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
+        # )
+
+        beta_prod_t = 1 - alpha_prod_t
+        # 3. compute predicted original sample from predicted noise also called
+        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        if self.scheduler.config.prediction_type == "epsilon":
+            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+        elif self.scheduler.config.prediction_type == "sample":
+            pred_original_sample = model_output
+        elif self.scheduler.config.prediction_type == "v_prediction":
+            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+            # predict V
+            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
+                " or `v_prediction`"
+            )
+        # # 4. Clip "predicted x_0"
+        # if self.scheduler.config.clip_sample:
+        #     pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+        return pred_original_sample
+
+    def pred_eps_from_noise(self, sample, model_output, timestep):
+        # 1. get previous step value (=t-1)
+        # prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
+
+        # 2. compute alphas, betas
+        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+        # alpha_prod_t_prev = (
+        #     self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
+        # )
+
+        beta_prod_t = 1 - alpha_prod_t
+        # 3. compute predicted eps from model output
+        if self.scheduler.config.prediction_type == "epsilon":
+            pred_eps = model_output
+        elif self.scheduler.config.prediction_type == "sample":
+            pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5)
+        elif self.scheduler.config.prediction_type == "v_prediction":
+            pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
+                " or `v_prediction`"
+            )
+
+        return pred_eps
diff --git a/train_dreambooth.py b/train_dreambooth.py
index a29c507..8ac70e8 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -304,7 +304,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=1e-2,
+        default=0,
         help="Weight decay to use."
     )
     parser.add_argument(
diff --git a/train_lora.py b/train_lora.py
index ab1753b..5fd05cc 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -177,11 +177,6 @@ def parse_args():
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
-    parser.add_argument(
-        "--gradient_checkpointing",
-        action="store_true",
-        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
-    )
     parser.add_argument(
         "--find_lr",
         action="store_true",
@@ -429,9 +424,6 @@ def main():
         vae.set_use_memory_efficient_attention_xformers(True)
         unet.enable_xformers_memory_efficient_attention()
 
-    if args.gradient_checkpointing:
-        unet.enable_gradient_checkpointing()
-
     unet.to(accelerator.device, dtype=weight_dtype)
     text_encoder.to(accelerator.device, dtype=weight_dtype)
 
diff --git a/train_ti.py b/train_ti.py
index 2840def..c79dfa2 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -143,7 +143,7 @@ def parse_args():
     parser.add_argument(
         "--num_buckets",
         type=int,
-        default=0,
+        default=4,
         help="Number of aspect ratio buckets in either direction.",
     )
     parser.add_argument(
--
cgit v1.2.3-70-g09d2
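
A minimal usage sketch for the SAG path introduced by this patch, assuming the custom pipeline is assembled from stock Stable Diffusion components. The checkpoint ID, prompt, output filename, and constructor keywords are illustrative assumptions (the constructor is not shown in the patch); only sag_scale, guidance_scale, num_inference_steps, and generator come from the __call__ signature changed above.

import torch
from diffusers import StableDiffusionPipeline

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

# Assumption: reuse components from an off-the-shelf SD checkpoint; the model ID is only an
# example, and the keyword names mirror StableDiffusionPipeline rather than being confirmed here.
base = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = VlpnStableDiffusion(
    vae=base.vae,
    text_encoder=base.text_encoder,
    tokenizer=base.tokenizer,
    unet=base.unet,
    scheduler=base.scheduler,
).to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
result = pipe(
    prompt="a photo of an astronaut riding a horse",  # placeholder prompt
    num_inference_steps=50,
    guidance_scale=7.5,  # classifier-free guidance; > 1.0 enables it
    sag_scale=0.75,      # > 0.0 enables the self-attention guidance branch
    generator=generator,
)
result.images[0].save("sag_example.png")  # placeholder output path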