From 8364ce697ddf6117fdd4f7222832d546d63880de Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 21 Jun 2023 13:28:49 +0200
Subject: Update

---
 .../stable_diffusion/vlpn_stable_diffusion.py | 262 +++++++++++++++------
 1 file changed, 188 insertions(+), 74 deletions(-)

(limited to 'pipelines/stable_diffusion')

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index aa446ec..16b8456 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -21,7 +21,9 @@ from diffusers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
+    StableDiffusionPipelineOutput,
+)
 from diffusers.utils import logging, randn_tensor
 from transformers import CLIPTextModel, CLIPTokenizer
 
@@ -62,13 +64,35 @@ def gaussian_blur_2d(img, kernel_size, sigma):
     return img
 
 
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    """
+    std_text = noise_pred_text.std(
+        dim=list(range(1, noise_pred_text.ndim)), keepdim=True
+    )
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = (
+        guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    )
+    return noise_cfg
+
+
 class CrossAttnStoreProcessor:
     def __init__(self):
         self.attention_probs = None
 
-    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
+    def __call__(
+        self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None
+    ):
         batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        attention_mask = attn.prepare_attention_mask(
+            attention_mask, sequence_length, batch_size
+        )
 
         query = attn.to_q(hidden_states)
         if encoder_hidden_states is None:
@@ -113,7 +137,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
     ):
         super().__init__()
 
-        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+        if (
+            hasattr(scheduler.config, "steps_offset")
+            and scheduler.config.steps_offset != 1
+        ):
             warnings.warn(
                 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                 f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -179,7 +206,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         device = torch.device("cuda")
 
-        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+        for cpu_offloaded_model in [
+            self.unet,
+            self.text_encoder,
+            self.vae,
+            self.safety_checker,
+        ]:
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)
 
@@ -223,35 +255,47 @@ class VlpnStableDiffusion(DiffusionPipeline):
         width: int,
         height: int,
         strength: float,
-        callback_steps: Optional[int]
+        callback_steps: Optional[int],
     ):
-        if isinstance(prompt, str) or (isinstance(prompt, list) and isinstance(prompt[0], int)):
+        if isinstance(prompt, str) or (
+            isinstance(prompt, list) and isinstance(prompt[0], int)
+        ):
             prompt = [prompt]
 
         if negative_prompt is None:
             negative_prompt = ""
 
-        if isinstance(negative_prompt, str) or (isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int)):
+        if isinstance(negative_prompt, str) or (
+            isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int)
+        ):
             negative_prompt = [negative_prompt] * len(prompt)
 
         if not isinstance(prompt, list):
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+            raise ValueError(
+                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
+            )
 
         if not isinstance(negative_prompt, list):
-            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+            raise ValueError(
+                f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}"
+            )
 
         if len(negative_prompt) != len(prompt):
             raise ValueError(
-                f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}")
+                f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}"
+            )
 
         if strength < 0 or strength > 1:
             raise ValueError(f"`strength` should in [0.0, 1.0] but is {strength}")
 
         if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+            raise ValueError(
+                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
+            )
 
         if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+            callback_steps is not None
+            and (not isinstance(callback_steps, int) or callback_steps <= 0)
         ):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
@@ -266,7 +310,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         negative_prompt: Union[List[str], List[List[int]]],
         num_images_per_prompt: int,
         do_classifier_free_guidance: bool,
-        device
+        device,
     ):
         if isinstance(prompt[0], str):
             text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids
@@ -277,7 +321,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         if do_classifier_free_guidance:
             if isinstance(prompt[0], str):
-                unconditional_input_ids = self.tokenizer(negative_prompt, padding="do_not_pad").input_ids
+                unconditional_input_ids = self.tokenizer(
+                    negative_prompt, padding="do_not_pad"
+                ).input_ids
             else:
                 unconditional_input_ids = negative_prompt
             unconditional_input_ids *= num_images_per_prompt
@@ -286,12 +332,17 @@ class VlpnStableDiffusion(DiffusionPipeline):
         text_inputs = unify_input_ids(self.tokenizer, text_input_ids)
         text_input_ids = text_inputs.input_ids
 
-        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+        if (
+            hasattr(self.text_encoder.config, "use_attention_mask")
+            and self.text_encoder.config.use_attention_mask
+        ):
             attention_mask = text_inputs.attention_mask.to(device)
         else:
             attention_mask = None
 
-        prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids.to(device), attention_mask)
+        prompt_embeds = get_extended_embeddings(
+            self.text_encoder, text_input_ids.to(device), attention_mask
+        )
         prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
 
         return prompt_embeds
@@ -301,25 +352,21 @@ class VlpnStableDiffusion(DiffusionPipeline):
         init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
         t_start = max(num_inference_steps - init_timestep, 0)
 
-        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
         timesteps = timesteps.to(device)
 
         return timesteps, num_inference_steps - t_start
 
-    def prepare_brightness_offset(self, batch_size, height, width, dtype, device, generator=None):
-        offset_image = perlin_noise(
-            (batch_size, 1, width, height),
-            res=1,
-            generator=generator,
-            dtype=dtype,
-            device=device
-        )
-        offset_latents = self.vae.encode(offset_image).latent_dist.sample(generator=generator)
-        offset_latents = self.vae.config.scaling_factor * offset_latents
-        return offset_latents
-
-    def prepare_latents_from_image(self, init_image, timestep, batch_size, brightness_offset, dtype, device, generator=None):
+    def prepare_latents_from_image(
+        self,
+        init_image,
+        timestep,
+        batch_size,
+        dtype,
+        device,
+        generator=None,
+    ):
         init_image = init_image.to(device=device, dtype=dtype)
         latents = self.vae.encode(init_image).latent_dist.sample(generator=generator)
         latents = self.vae.config.scaling_factor * latents
@@ -333,20 +380,32 @@ class VlpnStableDiffusion(DiffusionPipeline):
             latents = torch.cat([latents] * batch_multiplier, dim=0)
 
         # add noise to latents using the timesteps
-        noise = torch.randn(latents.shape, generator=generator, device=device, dtype=dtype)
-
-        if brightness_offset != 0:
-            noise += brightness_offset * self.prepare_brightness_offset(
-                batch_size, init_image.shape[3], init_image.shape[2], dtype, device, generator
-            )
+        noise = torch.randn(
+            latents.shape, generator=generator, device=device, dtype=dtype
+        )
 
         # get latents
         latents = self.scheduler.add_noise(latents, noise, timestep)
 
         return latents
 
-    def prepare_latents(self, batch_size, num_channels_latents, height, width, brightness_offset, dtype, device, generator, latents=None):
-        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+    def prepare_latents(
+        self,
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+    ):
+        shape = (
+            batch_size,
+            num_channels_latents,
+            height // self.vae_scale_factor,
+            width // self.vae_scale_factor,
+        )
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -354,15 +413,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
 
         if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = randn_tensor(
+                shape, generator=generator, device=device, dtype=dtype
+            )
         else:
             latents = latents.to(device)
 
-        if brightness_offset != 0:
-            latents += brightness_offset * self.prepare_brightness_offset(
-                batch_size, height, width, dtype, device, generator
-            )
-
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
         return latents
@@ -373,13 +429,17 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         # and should be between [0, 1]
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        accepts_eta = "eta" in set(
+            inspect.signature(self.scheduler.step).parameters.keys()
+        )
         extra_step_kwargs = {}
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
         # check if the scheduler accepts generator
-        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        accepts_generator = "generator" in set(
+            inspect.signature(self.scheduler.step).parameters.keys()
+        )
         if accepts_generator:
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
@@ -396,7 +456,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
     def __call__(
        self,
         prompt: Union[str, List[str], List[int], List[List[int]]],
-        negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None,
+        negative_prompt: Optional[
+            Union[str, List[str], List[int], List[List[int]]]
+        ] = None,
         num_images_per_prompt: int = 1,
         strength: float = 1.0,
         height: Optional[int] = None,
@@ -407,12 +469,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
-        brightness_offset: Union[float, torch.FloatTensor] = 0,
         output_type: str = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance_rescale: float = 0.0,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -472,7 +534,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
         width = width or self.unet.config.sample_size * self.vae_scale_factor
 
         # 1. Check inputs. Raise error if not correct
-        prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps)
+        prompt, negative_prompt = self.check_inputs(
+            prompt, negative_prompt, width, height, strength, callback_steps
+        )
 
         # 2. Define call parameters
         batch_size = len(prompt)
@@ -488,7 +552,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             negative_prompt,
             num_images_per_prompt,
             do_classifier_free_guidance,
-            device
+            device,
         )
 
         # 4. Prepare latent variables
@@ -497,7 +561,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+        timesteps, num_inference_steps = self.get_timesteps(
+            num_inference_steps, strength, device
+        )
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
 
         # 6. Prepare latent variables
@@ -506,7 +572,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 image,
                 latent_timestep,
                 batch_size * num_images_per_prompt,
-                brightness_offset,
                 prompt_embeds.dtype,
                 device,
                 generator,
@@ -517,7 +582,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 num_channels_latents,
                 height,
                 width,
-                brightness_offset,
                 prompt_embeds.dtype,
                 device,
                 generator,
@@ -530,14 +594,20 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # 8. Denoising loop
         if do_self_attention_guidance:
             store_processor = CrossAttnStoreProcessor()
-            self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
+            self.unet.mid_block.attentions[0].transformer_blocks[
+                0
+            ].attn1.processor = store_processor
 
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                latent_model_input = (
+                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                )
+                latent_model_input = self.scheduler.scale_model_input(
+                    latent_model_input, t
+                )
 
                 # predict the noise residual
                 noise_pred = self.unet(
@@ -551,7 +621,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 # perform guidance
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    noise_pred = noise_pred_uncond + guidance_scale * (
+                        noise_pred_text - noise_pred_uncond
+                    )
+                    noise_pred = rescale_noise_cfg(
+                        noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
+                    )
 
                 if do_self_attention_guidance:
                     # classifier-free guidance produces two chunks of attention map
@@ -561,15 +636,24 @@ class VlpnStableDiffusion(DiffusionPipeline):
                         # DDIM-like prediction of x0
                         pred_x0 = self.pred_x0(latents, noise_pred_uncond, t)
                         # get the stored attention maps
-                        uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
+                        uncond_attn, cond_attn = store_processor.attention_probs.chunk(
+                            2
+                        )
                         # self-attention-based degrading of latents
                         degraded_latents = self.sag_masking(
-                            pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t)
+                            pred_x0,
+                            uncond_attn,
+                            t,
+                            self.pred_epsilon(latents, noise_pred_uncond, t),
                         )
                         uncond_emb, _ = prompt_embeds.chunk(2)
                         # forward and give guidance
                         degraded_pred = self.unet(
-                            degraded_latents, t, encoder_hidden_states=uncond_emb, return_dict=False)[0]
+                            degraded_latents,
+                            t,
+                            encoder_hidden_states=uncond_emb,
+                            return_dict=False,
+                        )[0]
                         noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
                     else:
                         # DDIM-like prediction of x0
@@ -578,18 +662,29 @@ class VlpnStableDiffusion(DiffusionPipeline):
                         cond_attn = store_processor.attention_probs
                         # self-attention-based degrading of latents
                         degraded_latents = self.sag_masking(
-                            pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t)
+                            pred_x0,
+                            cond_attn,
+                            t,
+                            self.pred_epsilon(latents, noise_pred, t),
                         )
                         # forward and give guidance
                         degraded_pred = self.unet(
-                            degraded_latents, t, encoder_hidden_states=prompt_embeds, return_dict=False)[0]
+                            degraded_latents,
+                            t,
+                            encoder_hidden_states=prompt_embeds,
+                            return_dict=False,
+                        )[0]
                         noise_pred += sag_scale * (noise_pred - degraded_pred)
 
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                latents = self.scheduler.step(
+                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False
+                )[0]
 
                 # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                if i == len(timesteps) - 1 or (
+                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+                ):
                     progress_bar.update()
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, latents)
@@ -615,7 +710,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         if not return_dict:
             return (image, has_nsfw_concept)
 
-        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+        return StableDiffusionPipelineOutput(
+            images=image, nsfw_content_detected=has_nsfw_concept
+        )
 
     # Self-Attention-Guided (SAG) Stable Diffusion
@@ -632,16 +729,23 @@ class VlpnStableDiffusion(DiffusionPipeline):
         attn_map = attn_map.reshape(b, h, hw1, hw2)
         attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
         attn_mask = (
-            attn_mask.reshape(b, map_size, map_size).unsqueeze(1).repeat(1, latent_channel, 1, 1).type(attn_map.dtype)
+            attn_mask.reshape(b, map_size, map_size)
+            .unsqueeze(1)
+            .repeat(1, latent_channel, 1, 1)
+            .type(attn_map.dtype)
         )
         attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w))
 
         # Blur according to the self-attention mask
         degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0)
-        degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask)
+        degraded_latents = degraded_latents * attn_mask + original_latents * (
+            1 - attn_mask
+        )
 
         # Noise it again to match the noise level
-        degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t)
+        degraded_latents = self.scheduler.add_noise(
+            degraded_latents, noise=eps, timesteps=t
+        )
 
         return degraded_latents
 
@@ -652,13 +756,19 @@ class VlpnStableDiffusion(DiffusionPipeline):
         beta_prod_t = 1 - alpha_prod_t
 
         if self.scheduler.config.prediction_type == "epsilon":
-            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+            pred_original_sample = (
+                sample - beta_prod_t ** (0.5) * model_output
+            ) / alpha_prod_t ** (0.5)
         elif self.scheduler.config.prediction_type == "sample":
             pred_original_sample = model_output
         elif self.scheduler.config.prediction_type == "v_prediction":
-            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+            pred_original_sample = (alpha_prod_t**0.5) * sample - (
+                beta_prod_t**0.5
+            ) * model_output
             # predict V
-            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+            model_output = (alpha_prod_t**0.5) * model_output + (
+                beta_prod_t**0.5
+            ) * sample
         else:
             raise ValueError(
                 f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
@@ -674,9 +784,13 @@ class VlpnStableDiffusion(DiffusionPipeline):
         if self.scheduler.config.prediction_type == "epsilon":
             pred_eps = model_output
         elif self.scheduler.config.prediction_type == "sample":
-            pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5)
+            pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (
+                beta_prod_t**0.5
+            )
         elif self.scheduler.config.prediction_type == "v_prediction":
-            pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output
+            pred_eps = (beta_prod_t**0.5) * sample + (
+                alpha_prod_t**0.5
+            ) * model_output
         else:
             raise ValueError(
                 f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
-- 
cgit v1.2.3-70-g09d2
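
Usage sketch: after this change the pipeline no longer accepts `brightness_offset`; overexposure is instead addressed through the new `guidance_rescale` argument, which is applied right after classifier-free guidance via `rescale_noise_cfg`. A minimal sketch of calling the updated pipeline, assuming `pipe` is an already-constructed `VlpnStableDiffusion` instance; the prompt, seed, and the 0.7 rescale value below are illustrative placeholders, not values taken from this patch:

    import torch

    # Assumed: `pipe` is a VlpnStableDiffusion instance whose components (VAE, UNet,
    # text encoder, tokenizer, scheduler) were loaded elsewhere.
    generator = torch.Generator(device="cuda").manual_seed(0)  # placeholder seed

    output = pipe(
        prompt="a watercolor painting of a fox",  # placeholder prompt
        negative_prompt="blurry, low quality",    # placeholder negative prompt
        num_inference_steps=30,
        guidance_scale=7.5,
        guidance_rescale=0.7,  # new parameter; 0.0 (the default) disables rescaling
        generator=generator,
    )
    image = output.images[0]

With guidance_rescale=0.0 the classifier-free-guidance result is returned unchanged; values closer to 1.0 pull the guided noise prediction toward the standard deviation of the text-conditioned prediction, which the referenced paper reports as a fix for overexposed samples.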