Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r--   pipelines/stable_diffusion/vlpn_stable_diffusion.py   262
1 file changed, 188 insertions(+), 74 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index aa446ec..16b8456 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -21,7 +21,9 @@ from diffusers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
+    StableDiffusionPipelineOutput,
+)
 from diffusers.utils import logging, randn_tensor
 from transformers import CLIPTextModel, CLIPTokenizer
 
@@ -62,13 +64,35 @@ def gaussian_blur_2d(img, kernel_size, sigma):
     return img
 
 
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    """
+    std_text = noise_pred_text.std(
+        dim=list(range(1, noise_pred_text.ndim)), keepdim=True
+    )
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = (
+        guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    )
+    return noise_cfg
+
+
 class CrossAttnStoreProcessor:
     def __init__(self):
         self.attention_probs = None
 
-    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
+    def __call__(
+        self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None
+    ):
         batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        attention_mask = attn.prepare_attention_mask(
+            attention_mask, sequence_length, batch_size
+        )
         query = attn.to_q(hidden_states)
 
         if encoder_hidden_states is None:
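For orientation, here is a minimal sketch, separate from the patch itself, of how the new rescale_noise_cfg helper combines with classifier-free guidance; the tensor shapes and the 0.7 rescale value are illustrative assumptions only.

    import torch

    # assumes rescale_noise_cfg as defined in the hunk above is in scope
    noise_pred_uncond = torch.randn(2, 4, 64, 64)
    noise_pred_text = torch.randn(2, 4, 64, 64)
    guidance_scale = 7.5

    # standard classifier-free guidance combination
    noise_cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # pull the combined prediction back toward the per-sample std of the
    # text-conditioned branch; guidance_rescale=0.0 leaves noise_cfg unchanged
    noise_cfg = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)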
@@ -113,7 +137,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
     ):
         super().__init__()
 
-        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+        if (
+            hasattr(scheduler.config, "steps_offset")
+            and scheduler.config.steps_offset != 1
+        ):
             warnings.warn(
                 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                 f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -179,7 +206,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         device = torch.device("cuda")
 
-        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+        for cpu_offloaded_model in [
+            self.unet,
+            self.text_encoder,
+            self.vae,
+            self.safety_checker,
+        ]:
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)
 
@@ -223,35 +255,47 @@ class VlpnStableDiffusion(DiffusionPipeline):
         width: int,
         height: int,
         strength: float,
-        callback_steps: Optional[int]
+        callback_steps: Optional[int],
     ):
-        if isinstance(prompt, str) or (isinstance(prompt, list) and isinstance(prompt[0], int)):
+        if isinstance(prompt, str) or (
+            isinstance(prompt, list) and isinstance(prompt[0], int)
+        ):
             prompt = [prompt]
 
         if negative_prompt is None:
             negative_prompt = ""
 
-        if isinstance(negative_prompt, str) or (isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int)):
+        if isinstance(negative_prompt, str) or (
+            isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int)
+        ):
             negative_prompt = [negative_prompt] * len(prompt)
 
         if not isinstance(prompt, list):
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+            raise ValueError(
+                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
+            )
 
         if not isinstance(negative_prompt, list):
-            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+            raise ValueError(
+                f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}"
+            )
 
         if len(negative_prompt) != len(prompt):
             raise ValueError(
-                f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}")
+                f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}"
+            )
 
         if strength < 0 or strength > 1:
             raise ValueError(f"`strength` should in [0.0, 1.0] but is {strength}")
 
         if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+            raise ValueError(
+                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
+            )
 
         if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+            callback_steps is not None
+            and (not isinstance(callback_steps, int) or callback_steps <= 0)
         ):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
@@ -266,7 +310,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         negative_prompt: Union[List[str], List[List[int]]],
         num_images_per_prompt: int,
         do_classifier_free_guidance: bool,
-        device
+        device,
     ):
         if isinstance(prompt[0], str):
             text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids
@@ -277,7 +321,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         if do_classifier_free_guidance:
             if isinstance(prompt[0], str):
-                unconditional_input_ids = self.tokenizer(negative_prompt, padding="do_not_pad").input_ids
+                unconditional_input_ids = self.tokenizer(
+                    negative_prompt, padding="do_not_pad"
+                ).input_ids
             else:
                 unconditional_input_ids = negative_prompt
             unconditional_input_ids *= num_images_per_prompt
@@ -286,12 +332,17 @@ class VlpnStableDiffusion(DiffusionPipeline):
         text_inputs = unify_input_ids(self.tokenizer, text_input_ids)
         text_input_ids = text_inputs.input_ids
 
-        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+        if (
+            hasattr(self.text_encoder.config, "use_attention_mask")
+            and self.text_encoder.config.use_attention_mask
+        ):
             attention_mask = text_inputs.attention_mask.to(device)
         else:
             attention_mask = None
 
-        prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids.to(device), attention_mask)
+        prompt_embeds = get_extended_embeddings(
+            self.text_encoder, text_input_ids.to(device), attention_mask
+        )
         prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
 
         return prompt_embeds
@@ -301,25 +352,21 @@ class VlpnStableDiffusion(DiffusionPipeline):
         init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
 
         t_start = max(num_inference_steps - init_timestep, 0)
-        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
 
         timesteps = timesteps.to(device)
 
         return timesteps, num_inference_steps - t_start
 
-    def prepare_brightness_offset(self, batch_size, height, width, dtype, device, generator=None):
-        offset_image = perlin_noise(
-            (batch_size, 1, width, height),
-            res=1,
-            generator=generator,
-            dtype=dtype,
-            device=device
-        )
-        offset_latents = self.vae.encode(offset_image).latent_dist.sample(generator=generator)
-        offset_latents = self.vae.config.scaling_factor * offset_latents
-        return offset_latents
-
-    def prepare_latents_from_image(self, init_image, timestep, batch_size, brightness_offset, dtype, device, generator=None):
+    def prepare_latents_from_image(
+        self,
+        init_image,
+        timestep,
+        batch_size,
+        dtype,
+        device,
+        generator=None,
+    ):
         init_image = init_image.to(device=device, dtype=dtype)
         latents = self.vae.encode(init_image).latent_dist.sample(generator=generator)
         latents = self.vae.config.scaling_factor * latents
@@ -333,20 +380,32 @@ class VlpnStableDiffusion(DiffusionPipeline):
         latents = torch.cat([latents] * batch_multiplier, dim=0)
 
         # add noise to latents using the timesteps
-        noise = torch.randn(latents.shape, generator=generator, device=device, dtype=dtype)
-
-        if brightness_offset != 0:
-            noise += brightness_offset * self.prepare_brightness_offset(
-                batch_size, init_image.shape[3], init_image.shape[2], dtype, device, generator
-            )
+        noise = torch.randn(
+            latents.shape, generator=generator, device=device, dtype=dtype
+        )
 
         # get latents
         latents = self.scheduler.add_noise(latents, noise, timestep)
 
         return latents
 
-    def prepare_latents(self, batch_size, num_channels_latents, height, width, brightness_offset, dtype, device, generator, latents=None):
-        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+    def prepare_latents(
+        self,
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+    ):
+        shape = (
+            batch_size,
+            num_channels_latents,
+            height // self.vae_scale_factor,
+            width // self.vae_scale_factor,
+        )
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -354,15 +413,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
 
         if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = randn_tensor(
+                shape, generator=generator, device=device, dtype=dtype
+            )
         else:
             latents = latents.to(device)
 
-        if brightness_offset != 0:
-            latents += brightness_offset * self.prepare_brightness_offset(
-                batch_size, height, width, dtype, device, generator
-            )
-
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
         return latents
@@ -373,13 +429,17 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         # and should be between [0, 1]
 
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        accepts_eta = "eta" in set(
+            inspect.signature(self.scheduler.step).parameters.keys()
+        )
         extra_step_kwargs = {}
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
         # check if the scheduler accepts generator
-        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        accepts_generator = "generator" in set(
+            inspect.signature(self.scheduler.step).parameters.keys()
+        )
         if accepts_generator:
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
@@ -396,7 +456,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
     def __call__(
         self,
         prompt: Union[str, List[str], List[int], List[List[int]]],
-        negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None,
+        negative_prompt: Optional[
+            Union[str, List[str], List[int], List[List[int]]]
+        ] = None,
         num_images_per_prompt: int = 1,
         strength: float = 1.0,
         height: Optional[int] = None,
@@ -407,12 +469,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
-        brightness_offset: Union[float, torch.FloatTensor] = 0,
         output_type: str = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance_rescale: float = 0.0,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -472,7 +534,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
         width = width or self.unet.config.sample_size * self.vae_scale_factor
 
         # 1. Check inputs. Raise error if not correct
-        prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps)
+        prompt, negative_prompt = self.check_inputs(
+            prompt, negative_prompt, width, height, strength, callback_steps
+        )
 
         # 2. Define call parameters
         batch_size = len(prompt)
@@ -488,7 +552,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             negative_prompt,
             num_images_per_prompt,
             do_classifier_free_guidance,
-            device
+            device,
         )
 
         # 4. Prepare latent variables
@@ -497,7 +561,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+        timesteps, num_inference_steps = self.get_timesteps(
+            num_inference_steps, strength, device
+        )
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
 
         # 6. Prepare latent variables
@@ -506,7 +572,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 image,
                 latent_timestep,
                 batch_size * num_images_per_prompt,
-                brightness_offset,
                 prompt_embeds.dtype,
                 device,
                 generator,
@@ -517,7 +582,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 num_channels_latents,
                 height,
                 width,
-                brightness_offset,
                 prompt_embeds.dtype,
                 device,
                 generator,
@@ -530,14 +594,20 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # 8. Denoising loo
         if do_self_attention_guidance:
             store_processor = CrossAttnStoreProcessor()
-            self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
+            self.unet.mid_block.attentions[0].transformer_blocks[
+                0
+            ].attn1.processor = store_processor
 
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                latent_model_input = (
+                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                )
+                latent_model_input = self.scheduler.scale_model_input(
+                    latent_model_input, t
+                )
 
                 # predict the noise residual
                 noise_pred = self.unet(
@@ -551,7 +621,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 # perform guidance
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    noise_pred = noise_pred_uncond + guidance_scale * (
+                        noise_pred_text - noise_pred_uncond
+                    )
+                    noise_pred = rescale_noise_cfg(
+                        noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
+                    )
 
                 if do_self_attention_guidance:
                     # classifier-free guidance produces two chunks of attention map
@@ -561,15 +636,24 @@ class VlpnStableDiffusion(DiffusionPipeline):
                         # DDIM-like prediction of x0
                         pred_x0 = self.pred_x0(latents, noise_pred_uncond, t)
                         # get the stored attention maps
-                        uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
+                        uncond_attn, cond_attn = store_processor.attention_probs.chunk(
+                            2
+                        )
                         # self-attention-based degrading of latents
                         degraded_latents = self.sag_masking(
-                            pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t)
+                            pred_x0,
+                            uncond_attn,
+                            t,
+                            self.pred_epsilon(latents, noise_pred_uncond, t),
                         )
                         uncond_emb, _ = prompt_embeds.chunk(2)
                         # forward and give guidance
                         degraded_pred = self.unet(
-                            degraded_latents, t, encoder_hidden_states=uncond_emb, return_dict=False)[0]
+                            degraded_latents,
+                            t,
+                            encoder_hidden_states=uncond_emb,
+                            return_dict=False,
+                        )[0]
                         noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
                     else:
                         # DDIM-like prediction of x0
@@ -578,18 +662,29 @@ class VlpnStableDiffusion(DiffusionPipeline):
                         cond_attn = store_processor.attention_probs
                         # self-attention-based degrading of latents
                         degraded_latents = self.sag_masking(
-                            pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t)
+                            pred_x0,
+                            cond_attn,
+                            t,
+                            self.pred_epsilon(latents, noise_pred, t),
                         )
                         # forward and give guidance
                         degraded_pred = self.unet(
-                            degraded_latents, t, encoder_hidden_states=prompt_embeds, return_dict=False)[0]
+                            degraded_latents,
+                            t,
+                            encoder_hidden_states=prompt_embeds,
+                            return_dict=False,
+                        )[0]
                         noise_pred += sag_scale * (noise_pred - degraded_pred)
 
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                latents = self.scheduler.step(
+                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False
+                )[0]
 
                 # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                if i == len(timesteps) - 1 or (
+                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+                ):
                     progress_bar.update()
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, latents)
@@ -615,7 +710,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
         if not return_dict:
             return (image, has_nsfw_concept)
 
-        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+        return StableDiffusionPipelineOutput(
+            images=image, nsfw_content_detected=has_nsfw_concept
+        )
 
     # Self-Attention-Guided (SAG) Stable Diffusion
 
@@ -632,16 +729,23 @@ class VlpnStableDiffusion(DiffusionPipeline):
         attn_map = attn_map.reshape(b, h, hw1, hw2)
         attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
         attn_mask = (
-            attn_mask.reshape(b, map_size, map_size).unsqueeze(1).repeat(1, latent_channel, 1, 1).type(attn_map.dtype)
+            attn_mask.reshape(b, map_size, map_size)
+            .unsqueeze(1)
+            .repeat(1, latent_channel, 1, 1)
+            .type(attn_map.dtype)
         )
         attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w))
 
         # Blur according to the self-attention mask
         degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0)
-        degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask)
+        degraded_latents = degraded_latents * attn_mask + original_latents * (
+            1 - attn_mask
+        )
 
         # Noise it again to match the noise level
-        degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t)
+        degraded_latents = self.scheduler.add_noise(
+            degraded_latents, noise=eps, timesteps=t
+        )
 
         return degraded_latents
 
@@ -652,13 +756,19 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         beta_prod_t = 1 - alpha_prod_t
         if self.scheduler.config.prediction_type == "epsilon":
-            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+            pred_original_sample = (
+                sample - beta_prod_t ** (0.5) * model_output
+            ) / alpha_prod_t ** (0.5)
         elif self.scheduler.config.prediction_type == "sample":
             pred_original_sample = model_output
         elif self.scheduler.config.prediction_type == "v_prediction":
-            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+            pred_original_sample = (alpha_prod_t**0.5) * sample - (
+                beta_prod_t**0.5
+            ) * model_output
             # predict V
-            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+            model_output = (alpha_prod_t**0.5) * model_output + (
+                beta_prod_t**0.5
+            ) * sample
         else:
             raise ValueError(
                 f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
@@ -674,9 +784,13 @@ class VlpnStableDiffusion(DiffusionPipeline):
         if self.scheduler.config.prediction_type == "epsilon":
             pred_eps = model_output
         elif self.scheduler.config.prediction_type == "sample":
-            pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5)
+            pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (
+                beta_prod_t**0.5
+            )
         elif self.scheduler.config.prediction_type == "v_prediction":
-            pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output
+            pred_eps = (beta_prod_t**0.5) * sample + (
+                alpha_prod_t**0.5
+            ) * model_output
         else:
             raise ValueError(
                 f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
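Overall, the commit removes the brightness_offset / prepare_brightness_offset path and threads a new guidance_rescale argument from __call__ into the classifier-free guidance step via rescale_noise_cfg. Below is a hedged usage sketch; the checkpoint id, dtype, and the call parameters other than guidance_rescale are assumptions made for illustration and are not part of this patch.

    import torch
    from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

    # hypothetical setup: any Stable Diffusion 1.x checkpoint whose components
    # match this pipeline's __init__ signature
    pipe = VlpnStableDiffusion.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    result = pipe(
        prompt="a watercolor painting of a lighthouse at dawn",
        num_inference_steps=30,
        guidance_scale=7.5,
        guidance_rescale=0.7,  # new in this commit; 0.0 reproduces the previous behaviour
    )
    image = result.images[0]

Callers that still pass the removed brightness_offset keyword should now fail with a TypeError, since __call__ no longer accepts it.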
