From 8ff51a771905d0d14a3c690f54eb644515730348 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 14 Nov 2022 18:41:38 +0100
Subject: Refactoring

---
 .../stable_diffusion/vlpn_stable_diffusion.py | 318 ++++++++++++++-------
 1 file changed, 214 insertions(+), 104 deletions(-)

(limited to 'pipelines/stable_diffusion')

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index ba057ba..d6b1cb1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -1,6 +1,6 @@
 import inspect
 import warnings
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Callable

 import numpy as np
 import torch
@@ -136,11 +136,165 @@ class VlpnStableDiffusion(DiffusionPipeline):
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)

+    @property
+    def execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+            return self.device
+        for module in self.unet.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
+    def check_inputs(self, prompt, negative_prompt, width, height, strength, callback_steps):
+        if isinstance(prompt, str):
+            prompt = [prompt]
+
+        if negative_prompt is None:
+            negative_prompt = ""
+
+        if isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt] * len(prompt)
+
+        if not isinstance(prompt, list):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if not isinstance(negative_prompt, list):
+            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+        if len(negative_prompt) != len(prompt):
+            raise ValueError(
+                f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}")
+
+        if strength < 0 or strength > 1:
+            raise ValueError(f"`strength` should be in [0.0, 1.0] but is {strength}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if (callback_steps is None) or (
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+        return prompt, negative_prompt
+
+    def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance):
+        text_input_ids = self.prompt_processor.get_input_ids(prompt)
+        text_input_ids *= num_images_per_prompt
+
+        if do_classifier_free_guidance:
+            unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt)
+            unconditional_input_ids *= num_images_per_prompt
+            text_input_ids = unconditional_input_ids + text_input_ids
+
+        text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids)
+        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids)
+
+        return text_embeddings
+
+    def get_timesteps(self, latents_are_image, num_inference_steps, strength, device):
+        if latents_are_image:
+            # get the original timestep using init_timestep
+            offset = self.scheduler.config.get("steps_offset", 0)
+            init_timestep = int(num_inference_steps * strength) + offset
+            init_timestep = min(init_timestep, num_inference_steps)
+
+            t_start = max(num_inference_steps - init_timestep + offset, 0)
+            timesteps = self.scheduler.timesteps[t_start:]
+        else:
+            timesteps = self.scheduler.timesteps
+
+        timesteps = timesteps.to(device)
+
+        return timesteps
+
+    def prepare_latents(self, batch_size, num_images_per_prompt, num_channels_latents, height, width, dtype, device, generator, latents=None):
+        shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
+
+        if latents is None:
+            if device.type == "mps":
+                # randn does not work reproducibly on mps
+                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+            else:
+                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            if latents.shape != shape:
+                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+
+        return latents
+
+    def prepare_latents_from_image(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+        init_image = init_image.to(device=device, dtype=dtype)
+        init_latent_dist = self.vae.encode(init_image).latent_dist
+        init_latents = init_latent_dist.sample(generator=generator)
+        init_latents = 0.18215 * init_latents
+
+        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+            additional_image_per_prompt = batch_size // init_latents.shape[0]
+            init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
+        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+            raise ValueError(
+                f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+            )
+        else:
+            init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+
+        # add noise to latents using the timesteps
+        noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
+
+        # get latents
+        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+        latents = init_latents
+
+        return latents
+
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def decode_latents(self, latents):
+        latents = 1 / 0.18215 * latents
+        image = self.vae.decode(latents).sample
+        image = (image / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        return image
+
     @torch.no_grad()
     def __call__(
         self,
         prompt: Union[str, List[str], List[List[str]]],
         negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
         strength: float = 0.8,
         height: Optional[int] = 512,
         width: Optional[int] = 512,
@@ -148,9 +302,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
-        latents: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
+        latents_or_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: Optional[int] = 1,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -202,110 +358,60 @@ class VlpnStableDiffusion(DiffusionPipeline):
             (nsfw) content, according to the `safety_checker`.
         """
-        if isinstance(prompt, str):
-            prompt = [prompt]
+        # 1. Check inputs. Raise error if not correct
+        prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps)

+        # 2. Define call parameters
         batch_size = len(prompt)
-
-        if negative_prompt is None:
-            negative_prompt = ""
-
-        if isinstance(negative_prompt, str):
-            negative_prompt = [negative_prompt] * batch_size
-
-        if len(negative_prompt) != len(prompt):
-            raise ValueError(
-                f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}")
-
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
-        if strength < 0 or strength > 1:
-            raise ValueError(f"`strength` should in [0.0, 1.0] but is {strength}")
-
-        # set timesteps
-        self.scheduler.set_timesteps(num_inference_steps)
-
-        # get prompt text embeddings
-        text_input_ids = self.prompt_processor.get_input_ids(prompt)
-
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
+        device = self.execution_device
         do_classifier_free_guidance = guidance_scale > 1.0

-        # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance:
-            unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt)
-            text_input_ids = unconditional_input_ids + text_input_ids
-
-        text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids)
-        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids)
-
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = num_inference_steps + offset
-
-        # get the initial random noise unless the user supplied it
-
-        # Unlike in other pipelines, latents need to be generated in the target device
-        # for 1-to-1 results reproducibility with the CompVis implementation.
-        # However this currently doesn't work in `mps`.
-        latents_dtype = text_embeddings.dtype
-        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        latents_are_image = isinstance(latents_or_image, PIL.Image.Image)

-        if latents is None:
-            if self.device.type == "mps":
-                # randn does not exist on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
-                    self.device
-                )
-            else:
-                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
-        elif isinstance(latents, PIL.Image.Image):
-            latents = preprocess(latents, width, height)
-            latents = latents.to(device=self.device, dtype=latents_dtype)
-            latent_dist = self.vae.encode(latents).latent_dist
-            latents = latent_dist.sample(generator=generator)
-            latents = 0.18215 * latents
-
-            # expand init_latents for batch_size
-            latents = torch.cat([latents] * batch_size, dim=0)
-
-            # get the original timestep using init_timestep
-            init_timestep = int(num_inference_steps * strength) + offset
-            init_timestep = min(init_timestep, num_inference_steps)
+        print(f">>> {device}")

-            timesteps = self.scheduler.timesteps[-init_timestep]
-            timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+        # 3. Encode input prompt
+        text_embeddings = self.encode_prompt(
+            prompt,
+            negative_prompt,
+            num_images_per_prompt,
+            do_classifier_free_guidance
+        )

-            # add noise to latents using the timesteps
-            noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
-            latents = self.scheduler.add_noise(latents, noise, timesteps)
+        # 4. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.get_timesteps(latents_are_image, num_inference_steps, strength, device)
+
+        # 5. Prepare latent variables
+        num_channels_latents = self.unet.in_channels
+        if latents_are_image:
+            latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+            latents = self.prepare_latents_from_image(
+                latents_or_image,
+                latent_timestep,
+                batch_size,
+                num_images_per_prompt,
+                text_embeddings.dtype,
+                device,
+                generator
+            )
         else:
-            if latents.shape != latents_shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            if latents.device != self.device:
-                raise ValueError(f"Unexpected latents device, got {latents.device}, expected {self.device}")
-
-        t_start = max(num_inference_steps - init_timestep + offset, 0)
-
-        # Some schedulers like PNDM have timesteps as arrays
-        # It's more optimzed to move all timesteps to correct device beforehand
-        timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
+            latents = self.prepare_latents(
+                batch_size,
+                num_images_per_prompt,
+                num_channels_latents,
+                height,
+                width,
+                text_embeddings.dtype,
+                device,
+                generator,
+                latents_or_image,
+            )

-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
-        # and should be between [0, 1]
-        scheduler_step_args = set(inspect.signature(self.scheduler.step).parameters.keys())
-        accepts_eta = "eta" in scheduler_step_args
-        extra_step_kwargs = {}
-        if accepts_eta:
-            extra_step_kwargs["eta"] = eta
-        accepts_generator = "generator" in scheduler_step_args
-        if generator is not None and accepts_generator:
-            extra_step_kwargs["generator"] = generator
+        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+        # 7. Denoising loop
+        for i, t in enumerate(self.progress_bar(timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -321,17 +427,21 @@ class VlpnStableDiffusion(DiffusionPipeline):
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

-        # scale and decode the image latents with vae
-        latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample
+            # call the callback, if provided
+            if callback is not None and i % callback_steps == 0:
+                callback(i, t, latents)

-        image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        # 8. Post-processing
+        image = self.decode_latents(latents)
+
+        # 9. Run safety checker
+        has_nsfw_concept = None

+        # 10. Convert to PIL
         if output_type == "pil":
             image = self.numpy_to_pil(image)

         if not return_dict:
-            return (image, None)
+            return (image, has_nsfw_concept)

-        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

--
cgit v1.2.3-54-g00ecf
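
Usage sketch (illustrative only, not part of the commit): the refactored `__call__` keeps a single entry point for text-to-image and image-to-image and reports progress through `callback`/`callback_steps`. The snippet assumes an already-constructed `VlpnStableDiffusion` instance named `pipe`; the prompt strings, step counts, and callback body are made-up example values.

# Hypothetical call into the refactored pipeline; `pipe`, the prompts, and the
# step counts are assumptions for illustration, not taken from the commit.
def on_step(step, timestep, latents):
    # invoked every `callback_steps` denoising steps with the current latents
    print(f"step {step}: latents shape {tuple(latents.shape)}")

result = pipe(
    prompt="a watercolor painting of a fox",
    negative_prompt="blurry, low quality",
    num_images_per_prompt=2,
    num_inference_steps=50,
    guidance_scale=7.5,
    callback=on_step,
    callback_steps=10,
)
images = result.images  # StableDiffusionPipelineOutput; nsfw_content_detected stays None here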
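
For reference, the `strength`/`steps_offset` arithmetic used by `get_timesteps` for the image-to-image path can be checked in isolation. The helper below re-derives the same window with plain integers; the function name and the example numbers are illustrative only.

# Standalone re-derivation of the img2img timestep window computed in get_timesteps()
def timestep_window_start(num_inference_steps, strength, offset=0):
    init_timestep = int(num_inference_steps * strength) + offset
    init_timestep = min(init_timestep, num_inference_steps)
    # scheduler.timesteps[t_start:] are the steps that actually run
    return max(num_inference_steps - init_timestep + offset, 0)

# 50 steps at strength 0.6 with steps_offset=1: the first 20 scheduler timesteps
# are skipped and the remaining 30 are denoised
assert timestep_window_start(50, 0.6, offset=1) == 20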