Diffstat (limited to 'pipelines/stable_diffusion')
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 316 |
1 file changed, 213 insertions, 103 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index ba057ba..d6b1cb1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
| @@ -1,6 +1,6 @@ | |||
| 1 | import inspect | 1 | import inspect |
| 2 | import warnings | 2 | import warnings |
| 3 | from typing import List, Optional, Union | 3 | from typing import List, Optional, Union, Callable |
| 4 | 4 | ||
| 5 | import numpy as np | 5 | import numpy as np |
| 6 | import torch | 6 | import torch |
| @@ -136,11 +136,165 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 136 | if cpu_offloaded_model is not None: | 136 | if cpu_offloaded_model is not None: |
| 137 | cpu_offload(cpu_offloaded_model, device) | 137 | cpu_offload(cpu_offloaded_model, device) |
| 138 | 138 | ||
| 139 | @property | ||
| 140 | def execution_device(self): | ||
| 141 | r""" | ||
| 142 | Returns the device on which the pipeline's models will be executed. After calling | ||
| 143 | `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module | ||
| 144 | hooks. | ||
| 145 | """ | ||
| 146 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): | ||
| 147 | return self.device | ||
| 148 | for module in self.unet.modules(): | ||
| 149 | if ( | ||
| 150 | hasattr(module, "_hf_hook") | ||
| 151 | and hasattr(module._hf_hook, "execution_device") | ||
| 152 | and module._hf_hook.execution_device is not None | ||
| 153 | ): | ||
| 154 | return torch.device(module._hf_hook.execution_device) | ||
| 155 | return self.device | ||
| 156 | |||
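The property above only matters once enable_sequential_cpu_offload() has moved the UNet to the meta device; at that point the real compute device is only recorded on Accelerate's module hooks. A minimal standalone sketch of the same lookup (the helper name is an assumption, not part of the pipeline):

    import torch
    import torch.nn as nn

    def execution_device_of(module: nn.Module, fallback: torch.device) -> torch.device:
        # Same lookup as the property above: walk submodules and return the first
        # Accelerate hook that records an execution device.
        for m in module.modules():
            hook = getattr(m, "_hf_hook", None)
            if hook is not None and getattr(hook, "execution_device", None) is not None:
                return torch.device(hook.execution_device)
        return fallback

    # With no Accelerate hooks attached this simply falls back to the given device:
    print(execution_device_of(nn.Linear(4, 4), torch.device("cpu")))  # cpu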
| 157 | def check_inputs(self, prompt, negative_prompt, width, height, strength, callback_steps): | ||
| 158 | if isinstance(prompt, str): | ||
| 159 | prompt = [prompt] | ||
| 160 | |||
| 161 | if negative_prompt is None: | ||
| 162 | negative_prompt = "" | ||
| 163 | |||
| 164 | if isinstance(negative_prompt, str): | ||
| 165 | negative_prompt = [negative_prompt] * len(prompt) | ||
| 166 | |||
| 167 | if not isinstance(prompt, list): | ||
| 168 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") | ||
| 169 | |||
| 170 | if not isinstance(negative_prompt, list): | ||
| 171 | raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") | ||
| 172 | |||
| 173 | if len(negative_prompt) != len(prompt): | ||
| 174 | raise ValueError( | ||
| 175 | f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}") | ||
| 176 | |||
| 177 | if strength < 0 or strength > 1: | ||
| 178 | raise ValueError(f"`strength` should be in [0.0, 1.0] but is {strength}") | ||
| 179 | |||
| 180 | if height % 8 != 0 or width % 8 != 0: | ||
| 181 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") | ||
| 182 | |||
| 183 | if (callback_steps is None) or ( | ||
| 184 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) | ||
| 185 | ): | ||
| 186 | raise ValueError( | ||
| 187 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" | ||
| 188 | f" {type(callback_steps)}." | ||
| 189 | ) | ||
| 190 | |||
| 191 | return prompt, negative_prompt | ||
| 192 | |||
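The main normalization performed here is broadcasting a single negative prompt across every positive prompt; a small sketch of that behaviour (hypothetical helper, not part of the pipeline):

    # Hypothetical helper mirroring the normalization above: a single negative
    # prompt is broadcast so both lists end up the same length.
    def normalize_prompts(prompt, negative_prompt=None):
        if isinstance(prompt, str):
            prompt = [prompt]
        if negative_prompt is None:
            negative_prompt = ""
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * len(prompt)
        return prompt, negative_prompt

    print(normalize_prompts(["a cat", "a dog"], "low quality"))
    # (['a cat', 'a dog'], ['low quality', 'low quality'])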
| 193 | def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance): | ||
| 194 | text_input_ids = self.prompt_processor.get_input_ids(prompt) | ||
| 195 | text_input_ids *= num_images_per_prompt | ||
| 196 | |||
| 197 | if do_classifier_free_guidance: | ||
| 198 | unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt) | ||
| 199 | unconditional_input_ids *= num_images_per_prompt | ||
| 200 | text_input_ids = unconditional_input_ids + text_input_ids | ||
| 201 | |||
| 202 | text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids) | ||
| 203 | text_embeddings = self.prompt_processor.get_embeddings(text_input_ids) | ||
| 204 | |||
| 205 | return text_embeddings | ||
| 206 | |||
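Prepending the unconditional ids is what lets the denoising loop later split the doubled batch with chunk(2) and apply classifier-free guidance. A minimal numeric sketch of that split and recombination (illustrative tensors only):

    import torch

    # Two prompts with the batch doubled: the first half is the unconditional
    # prediction, the second half the text-conditioned one.
    noise_pred = torch.randn(4, 3)
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    guidance_scale = 7.5
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    print(noise_pred.shape)  # torch.Size([2, 3])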
| 207 | def get_timesteps(self, latents_are_image, num_inference_steps, strength, device): | ||
| 208 | if latents_are_image: | ||
| 209 | # get the original timestep using init_timestep | ||
| 210 | offset = self.scheduler.config.get("steps_offset", 0) | ||
| 211 | init_timestep = int(num_inference_steps * strength) + offset | ||
| 212 | init_timestep = min(init_timestep, num_inference_steps) | ||
| 213 | |||
| 214 | t_start = max(num_inference_steps - init_timestep + offset, 0) | ||
| 215 | timesteps = self.scheduler.timesteps[t_start:] | ||
| 216 | else: | ||
| 217 | timesteps = self.scheduler.timesteps | ||
| 218 | |||
| 219 | timesteps = timesteps.to(device) | ||
| 220 | |||
| 221 | return timesteps | ||
| 222 | |||
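The strength parameter simply decides how many of the scheduler's timesteps are skipped before denoising starts. A worked example with assumed values (50 steps, strength 0.8, steps_offset 1):

    # Assumed values: 50 inference steps, strength 0.8, steps_offset 1.
    num_inference_steps, strength, offset = 50, 0.8, 1
    init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)
    t_start = max(num_inference_steps - init_timestep + offset, 0)
    print(init_timestep, t_start)  # 41 10 -> only the last 40 scheduler timesteps are run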
| 223 | def prepare_latents(self, batch_size, num_images_per_prompt, num_channels_latents, height, width, dtype, device, generator, latents=None): | ||
| 224 | shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) | ||
| 225 | |||
| 226 | if latents is None: | ||
| 227 | if device.type == "mps": | ||
| 228 | # randn does not work reproducibly on mps | ||
| 229 | latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) | ||
| 230 | else: | ||
| 231 | latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) | ||
| 232 | else: | ||
| 233 | if latents.shape != shape: | ||
| 234 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") | ||
| 235 | latents = latents.to(device) | ||
| 236 | |||
| 237 | # scale the initial noise by the standard deviation required by the scheduler | ||
| 238 | latents = latents * self.scheduler.init_noise_sigma | ||
| 239 | |||
| 240 | return latents | ||
| 241 | |||
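Multiplying by init_noise_sigma keeps the starting noise at the scale each scheduler expects: 1.0 for DDIM/PNDM-style schedulers, much larger for sigma-based ones. A small sketch with an assumed Euler-like value:

    import torch

    # Assumed sigma: DDIM/PNDM report init_noise_sigma == 1.0, Euler-style
    # schedulers report the largest sigma of their schedule (~14.6 for SD).
    generator = torch.Generator().manual_seed(0)
    latents = torch.randn((1, 4, 64, 64), generator=generator)  # 512x512 -> 64x64 latent grid
    init_noise_sigma = 14.6
    latents = latents * init_noise_sigma
    print(latents.std())  # roughly init_noise_sigma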
| 242 | def prepare_latents_from_image(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): | ||
| 243 | init_image = init_image.to(device=device, dtype=dtype) | ||
| 244 | init_latent_dist = self.vae.encode(init_image).latent_dist | ||
| 245 | init_latents = init_latent_dist.sample(generator=generator) | ||
| 246 | init_latents = 0.18215 * init_latents | ||
| 247 | |||
| 248 | if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: | ||
| 249 | additional_image_per_prompt = batch_size // init_latents.shape[0] | ||
| 250 | init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) | ||
| 251 | elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: | ||
| 252 | raise ValueError( | ||
| 253 | f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." | ||
| 254 | ) | ||
| 255 | else: | ||
| 256 | init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) | ||
| 257 | |||
| 258 | # add noise to latents using the timesteps | ||
| 259 | noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype) | ||
| 260 | |||
| 261 | # get latents | ||
| 262 | init_latents = self.scheduler.add_noise(init_latents, noise, timestep) | ||
| 263 | latents = init_latents | ||
| 264 | |||
| 265 | return latents | ||
| 266 | |||
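The 0.18215 factor is the Stable Diffusion VAE scaling constant; it is applied after encoding here and undone in decode_latents below, so the two operations round-trip. A quick numeric check (illustrative tensor only):

    import torch

    # Scaling by 0.18215 here and dividing by it in decode_latents are inverses.
    vae_latents = torch.randn(1, 4, 64, 64)
    scaled = 0.18215 * vae_latents        # stored and denoised by the pipeline
    restored = scaled / 0.18215           # undone before VAE decoding
    print(torch.allclose(restored, vae_latents))  # True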
| 267 | def prepare_extra_step_kwargs(self, generator, eta): | ||
| 268 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature | ||
| 269 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. | ||
| 270 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 | ||
| 271 | # and should be between [0, 1] | ||
| 272 | |||
| 273 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) | ||
| 274 | extra_step_kwargs = {} | ||
| 275 | if accepts_eta: | ||
| 276 | extra_step_kwargs["eta"] = eta | ||
| 277 | |||
| 278 | # check if the scheduler accepts generator | ||
| 279 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) | ||
| 280 | if accepts_generator: | ||
| 281 | extra_step_kwargs["generator"] = generator | ||
| 282 | return extra_step_kwargs | ||
| 283 | |||
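The introspection avoids passing eta or generator to schedulers whose step() does not accept them. A self-contained sketch of the same pattern (the dummy scheduler is an assumption for illustration):

    import inspect

    # The dummy scheduler accepts eta but not generator, so only eta ends up
    # in the extra kwargs.
    class DummyScheduler:
        def step(self, model_output, timestep, sample, eta: float = 0.0):
            return sample

    step_params = set(inspect.signature(DummyScheduler.step).parameters)
    extra_step_kwargs = {}
    if "eta" in step_params:
        extra_step_kwargs["eta"] = 0.0
    if "generator" in step_params:
        extra_step_kwargs["generator"] = None
    print(extra_step_kwargs)  # {'eta': 0.0}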
| 284 | def decode_latents(self, latents): | ||
| 285 | latents = 1 / 0.18215 * latents | ||
| 286 | image = self.vae.decode(latents).sample | ||
| 287 | image = (image / 2 + 0.5).clamp(0, 1) | ||
| 288 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 | ||
| 289 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() | ||
| 290 | return image | ||
| 291 | |||
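decode_latents reverses the scaling, decodes with the VAE, and converts the result to HWC float32 numpy. The post-VAE part can be sketched on its own (a random tensor stands in for the decoder output):

    import torch

    # A random tensor in [-1, 1] stands in for vae.decode(latents).sample.
    decoded = torch.rand(1, 3, 512, 512) * 2 - 1
    image = (decoded / 2 + 0.5).clamp(0, 1)                   # map to [0, 1]
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()   # NCHW -> NHWC float32
    print(image.shape, image.dtype)                           # (1, 512, 512, 3) float32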
| 139 | @torch.no_grad() | 292 | @torch.no_grad() |
| 140 | def __call__( | 293 | def __call__( |
| 141 | self, | 294 | self, |
| 142 | prompt: Union[str, List[str], List[List[str]]], | 295 | prompt: Union[str, List[str], List[List[str]]], |
| 143 | negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None, | 296 | negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None, |
| 297 | num_images_per_prompt: Optional[int] = 1, | ||
| 144 | strength: float = 0.8, | 298 | strength: float = 0.8, |
| 145 | height: Optional[int] = 512, | 299 | height: Optional[int] = 512, |
| 146 | width: Optional[int] = 512, | 300 | width: Optional[int] = 512, |
| @@ -148,9 +302,11 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 148 | guidance_scale: Optional[float] = 7.5, | 302 | guidance_scale: Optional[float] = 7.5, |
| 149 | eta: Optional[float] = 0.0, | 303 | eta: Optional[float] = 0.0, |
| 150 | generator: Optional[torch.Generator] = None, | 304 | generator: Optional[torch.Generator] = None, |
| 151 | latents: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, | 305 | latents_or_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, |
| 152 | output_type: Optional[str] = "pil", | 306 | output_type: Optional[str] = "pil", |
| 153 | return_dict: bool = True, | 307 | return_dict: bool = True, |
| 308 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | ||
| 309 | callback_steps: Optional[int] = 1, | ||
| 154 | ): | 310 | ): |
| 155 | r""" | 311 | r""" |
| 156 | Function invoked when calling the pipeline for generation. | 312 | Function invoked when calling the pipeline for generation. |
| @@ -202,110 +358,60 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 202 | (nsfw) content, according to the `safety_checker`. | 358 | (nsfw) content, according to the `safety_checker`. |
| 203 | """ | 359 | """ |
| 204 | 360 | ||
| 205 | if isinstance(prompt, str): | 361 | # 1. Check inputs. Raise error if not correct |
| 206 | prompt = [prompt] | 362 | prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps) |
| 207 | 363 | ||
| 364 | # 2. Define call parameters | ||
| 208 | batch_size = len(prompt) | 365 | batch_size = len(prompt) |
| 209 | 366 | device = self.execution_device | |
| 210 | if negative_prompt is None: | ||
| 211 | negative_prompt = "" | ||
| 212 | |||
| 213 | if isinstance(negative_prompt, str): | ||
| 214 | negative_prompt = [negative_prompt] * batch_size | ||
| 215 | |||
| 216 | if len(negative_prompt) != len(prompt): | ||
| 217 | raise ValueError( | ||
| 218 | f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}") | ||
| 219 | |||
| 220 | if height % 8 != 0 or width % 8 != 0: | ||
| 221 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") | ||
| 222 | |||
| 223 | if strength < 0 or strength > 1: | ||
| 224 | raise ValueError(f"`strength` should in [0.0, 1.0] but is {strength}") | ||
| 225 | |||
| 226 | # set timesteps | ||
| 227 | self.scheduler.set_timesteps(num_inference_steps) | ||
| 228 | |||
| 229 | # get prompt text embeddings | ||
| 230 | text_input_ids = self.prompt_processor.get_input_ids(prompt) | ||
| 231 | |||
| 232 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) | ||
| 233 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` | ||
| 234 | # corresponds to doing no classifier free guidance. | ||
| 235 | do_classifier_free_guidance = guidance_scale > 1.0 | 367 | do_classifier_free_guidance = guidance_scale > 1.0 |
| 236 | # get unconditional embeddings for classifier free guidance | 368 | latents_are_image = isinstance(latents_or_image, PIL.Image.Image) |
| 237 | if do_classifier_free_guidance: | ||
| 238 | unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt) | ||
| 239 | text_input_ids = unconditional_input_ids + text_input_ids | ||
| 240 | |||
| 241 | text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids) | ||
| 242 | text_embeddings = self.prompt_processor.get_embeddings(text_input_ids) | ||
| 243 | |||
| 244 | offset = self.scheduler.config.get("steps_offset", 0) | ||
| 245 | init_timestep = num_inference_steps + offset | ||
| 246 | |||
| 247 | # get the initial random noise unless the user supplied it | ||
| 248 | 369 | ||
| 249 | # Unlike in other pipelines, latents need to be generated in the target device | 370 | print(f">>> {device}") |
| 250 | # for 1-to-1 results reproducibility with the CompVis implementation. | ||
| 251 | # However this currently doesn't work in `mps`. | ||
| 252 | latents_dtype = text_embeddings.dtype | ||
| 253 | latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) | ||
| 254 | 371 | ||
| 255 | if latents is None: | 372 | # 3. Encode input prompt |
| 256 | if self.device.type == "mps": | 373 | text_embeddings = self.encode_prompt( |
| 257 | # randn does not exist on mps | 374 | prompt, |
| 258 | latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( | 375 | negative_prompt, |
| 259 | self.device | 376 | num_images_per_prompt, |
| 260 | ) | 377 | do_classifier_free_guidance |
| 261 | else: | 378 | ) |
| 262 | latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) | ||
| 263 | elif isinstance(latents, PIL.Image.Image): | ||
| 264 | latents = preprocess(latents, width, height) | ||
| 265 | latents = latents.to(device=self.device, dtype=latents_dtype) | ||
| 266 | latent_dist = self.vae.encode(latents).latent_dist | ||
| 267 | latents = latent_dist.sample(generator=generator) | ||
| 268 | latents = 0.18215 * latents | ||
| 269 | |||
| 270 | # expand init_latents for batch_size | ||
| 271 | latents = torch.cat([latents] * batch_size, dim=0) | ||
| 272 | |||
| 273 | # get the original timestep using init_timestep | ||
| 274 | init_timestep = int(num_inference_steps * strength) + offset | ||
| 275 | init_timestep = min(init_timestep, num_inference_steps) | ||
| 276 | 379 | ||
| 277 | timesteps = self.scheduler.timesteps[-init_timestep] | 380 | # 4. Prepare timesteps |
| 278 | timesteps = torch.tensor([timesteps] * batch_size, device=self.device) | 381 | self.scheduler.set_timesteps(num_inference_steps, device=device) |
| 382 | timesteps = self.get_timesteps(latents_are_image, num_inference_steps, strength, device) | ||
| 279 | 383 | ||
| 280 | # add noise to latents using the timesteps | 384 | # 5. Prepare latent variables |
| 281 | noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype) | 385 | num_channels_latents = self.unet.in_channels |
| 282 | latents = self.scheduler.add_noise(latents, noise, timesteps) | 386 | if latents_are_image: |
| 387 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) | ||
| 388 | latents = self.prepare_latents_from_image( | ||
| 389 | latents_or_image, | ||
| 390 | latent_timestep, | ||
| 391 | batch_size, | ||
| 392 | num_images_per_prompt, | ||
| 393 | text_embeddings.dtype, | ||
| 394 | device, | ||
| 395 | generator | ||
| 396 | ) | ||
| 283 | else: | 397 | else: |
| 284 | if latents.shape != latents_shape: | 398 | latents = self.prepare_latents( |
| 285 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") | 399 | batch_size, |
| 286 | if latents.device != self.device: | 400 | num_images_per_prompt, |
| 287 | raise ValueError(f"Unexpected latents device, got {latents.device}, expected {self.device}") | 401 | num_channels_latents, |
| 288 | 402 | height, | |
| 289 | t_start = max(num_inference_steps - init_timestep + offset, 0) | 403 | width, |
| 290 | 404 | text_embeddings.dtype, | |
| 291 | # Some schedulers like PNDM have timesteps as arrays | 405 | device, |
| 292 | # It's more optimzed to move all timesteps to correct device beforehand | 406 | generator, |
| 293 | timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device) | 407 | latents_or_image, |
| 408 | ) | ||
| 294 | 409 | ||
| 295 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature | 410 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline |
| 296 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. | 411 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
| 297 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 | ||
| 298 | # and should be between [0, 1] | ||
| 299 | scheduler_step_args = set(inspect.signature(self.scheduler.step).parameters.keys()) | ||
| 300 | accepts_eta = "eta" in scheduler_step_args | ||
| 301 | extra_step_kwargs = {} | ||
| 302 | if accepts_eta: | ||
| 303 | extra_step_kwargs["eta"] = eta | ||
| 304 | accepts_generator = "generator" in scheduler_step_args | ||
| 305 | if generator is not None and accepts_generator: | ||
| 306 | extra_step_kwargs["generator"] = generator | ||
| 307 | 412 | ||
| 308 | for i, t in enumerate(self.progress_bar(timesteps_tensor)): | 413 | # 7. Denoising loop |
| 414 | for i, t in enumerate(self.progress_bar(timesteps)): | ||
| 309 | # expand the latents if we are doing classifier free guidance | 415 | # expand the latents if we are doing classifier free guidance |
| 310 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | 416 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
| 311 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | 417 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
| @@ -321,17 +427,21 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 321 | # compute the previous noisy sample x_t -> x_t-1 | 427 | # compute the previous noisy sample x_t -> x_t-1 |
| 322 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample | 428 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample |
| 323 | 429 | ||
| 324 | # scale and decode the image latents with vae | 430 | # call the callback, if provided |
| 325 | latents = 1 / 0.18215 * latents | 431 | if callback is not None and i % callback_steps == 0: |
| 326 | image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample | 432 | callback(i, t, latents) |
| 327 | 433 | ||
| 328 | image = (image / 2 + 0.5).clamp(0, 1) | 434 | # 8. Post-processing |
| 329 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() | 435 | image = self.decode_latents(latents) |
| 436 | |||
| 437 | # 9. Run safety checker | ||
| 438 | has_nsfw_concept = None | ||
| 330 | 439 | ||
| 440 | # 10. Convert to PIL | ||
| 331 | if output_type == "pil": | 441 | if output_type == "pil": |
| 332 | image = self.numpy_to_pil(image) | 442 | image = self.numpy_to_pil(image) |
| 333 | 443 | ||
| 334 | if not return_dict: | 444 | if not return_dict: |
| 335 | return (image, None) | 445 | return (image, has_nsfw_concept) |
| 336 | 446 | ||
| 337 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None) | 447 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) |
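The new callback and callback_steps arguments fire on every callback_steps-th denoising step with the step index, the current timestep, and the current latents. A minimal sketch of that gating logic outside the pipeline (hypothetical loop, no diffusion involved):

    # Hypothetical loop showing how callback_steps gates how often the
    # callback fires during the denoising loop above.
    def run_loop(num_steps, callback=None, callback_steps=1):
        for i in range(num_steps):
            # ... one scheduler/unet step would happen here ...
            if callback is not None and i % callback_steps == 0:
                callback(i)

    run_loop(10, callback=lambda i: print("callback at step", i), callback_steps=4)
    # callback at step 0, 4 and 8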
