summaryrefslogtreecommitdiffstats
path: root/pipelines
diff options
context:
space:
mode:
Diffstat (limited to 'pipelines')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py262
1 files changed, 188 insertions, 74 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index aa446ec..16b8456 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -21,7 +21,9 @@ from diffusers import (
21 LMSDiscreteScheduler, 21 LMSDiscreteScheduler,
22 PNDMScheduler, 22 PNDMScheduler,
23) 23)
24from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput 24from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
25 StableDiffusionPipelineOutput,
26)
25from diffusers.utils import logging, randn_tensor 27from diffusers.utils import logging, randn_tensor
26from transformers import CLIPTextModel, CLIPTokenizer 28from transformers import CLIPTextModel, CLIPTokenizer
27 29
@@ -62,13 +64,35 @@ def gaussian_blur_2d(img, kernel_size, sigma):
62 return img 64 return img
63 65
64 66
67def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
68 """
69 Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
70 Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
71 """
72 std_text = noise_pred_text.std(
73 dim=list(range(1, noise_pred_text.ndim)), keepdim=True
74 )
75 std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
76 # rescale the results from guidance (fixes overexposure)
77 noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
78 # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
79 noise_cfg = (
80 guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
81 )
82 return noise_cfg
83
84
65class CrossAttnStoreProcessor: 85class CrossAttnStoreProcessor:
66 def __init__(self): 86 def __init__(self):
67 self.attention_probs = None 87 self.attention_probs = None
68 88
69 def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): 89 def __call__(
90 self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None
91 ):
70 batch_size, sequence_length, _ = hidden_states.shape 92 batch_size, sequence_length, _ = hidden_states.shape
71 attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 93 attention_mask = attn.prepare_attention_mask(
94 attention_mask, sequence_length, batch_size
95 )
72 query = attn.to_q(hidden_states) 96 query = attn.to_q(hidden_states)
73 97
74 if encoder_hidden_states is None: 98 if encoder_hidden_states is None:
@@ -113,7 +137,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
113 ): 137 ):
114 super().__init__() 138 super().__init__()
115 139
116 if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: 140 if (
141 hasattr(scheduler.config, "steps_offset")
142 and scheduler.config.steps_offset != 1
143 ):
117 warnings.warn( 144 warnings.warn(
118 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" 145 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
119 f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 146 f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -179,7 +206,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
179 206
180 device = torch.device("cuda") 207 device = torch.device("cuda")
181 208
182 for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: 209 for cpu_offloaded_model in [
210 self.unet,
211 self.text_encoder,
212 self.vae,
213 self.safety_checker,
214 ]:
183 if cpu_offloaded_model is not None: 215 if cpu_offloaded_model is not None:
184 cpu_offload(cpu_offloaded_model, device) 216 cpu_offload(cpu_offloaded_model, device)
185 217
@@ -223,35 +255,47 @@ class VlpnStableDiffusion(DiffusionPipeline):
223 width: int, 255 width: int,
224 height: int, 256 height: int,
225 strength: float, 257 strength: float,
226 callback_steps: Optional[int] 258 callback_steps: Optional[int],
227 ): 259 ):
228 if isinstance(prompt, str) or (isinstance(prompt, list) and isinstance(prompt[0], int)): 260 if isinstance(prompt, str) or (
261 isinstance(prompt, list) and isinstance(prompt[0], int)
262 ):
229 prompt = [prompt] 263 prompt = [prompt]
230 264
231 if negative_prompt is None: 265 if negative_prompt is None:
232 negative_prompt = "" 266 negative_prompt = ""
233 267
234 if isinstance(negative_prompt, str) or (isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int)): 268 if isinstance(negative_prompt, str) or (
269 isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int)
270 ):
235 negative_prompt = [negative_prompt] * len(prompt) 271 negative_prompt = [negative_prompt] * len(prompt)
236 272
237 if not isinstance(prompt, list): 273 if not isinstance(prompt, list):
238 raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 274 raise ValueError(
275 f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
276 )
239 277
240 if not isinstance(negative_prompt, list): 278 if not isinstance(negative_prompt, list):
241 raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") 279 raise ValueError(
280 f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}"
281 )
242 282
243 if len(negative_prompt) != len(prompt): 283 if len(negative_prompt) != len(prompt):
244 raise ValueError( 284 raise ValueError(
245 f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}") 285 f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}"
286 )
246 287
247 if strength < 0 or strength > 1: 288 if strength < 0 or strength > 1:
248 raise ValueError(f"`strength` should in [0.0, 1.0] but is {strength}") 289 raise ValueError(f"`strength` should in [0.0, 1.0] but is {strength}")
249 290
250 if height % 8 != 0 or width % 8 != 0: 291 if height % 8 != 0 or width % 8 != 0:
251 raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 292 raise ValueError(
293 f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
294 )
252 295
253 if (callback_steps is None) or ( 296 if (callback_steps is None) or (
254 callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 297 callback_steps is not None
298 and (not isinstance(callback_steps, int) or callback_steps <= 0)
255 ): 299 ):
256 raise ValueError( 300 raise ValueError(
257 f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 301 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
@@ -266,7 +310,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
266 negative_prompt: Union[List[str], List[List[int]]], 310 negative_prompt: Union[List[str], List[List[int]]],
267 num_images_per_prompt: int, 311 num_images_per_prompt: int,
268 do_classifier_free_guidance: bool, 312 do_classifier_free_guidance: bool,
269 device 313 device,
270 ): 314 ):
271 if isinstance(prompt[0], str): 315 if isinstance(prompt[0], str):
272 text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids 316 text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids
@@ -277,7 +321,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
277 321
278 if do_classifier_free_guidance: 322 if do_classifier_free_guidance:
279 if isinstance(prompt[0], str): 323 if isinstance(prompt[0], str):
280 unconditional_input_ids = self.tokenizer(negative_prompt, padding="do_not_pad").input_ids 324 unconditional_input_ids = self.tokenizer(
325 negative_prompt, padding="do_not_pad"
326 ).input_ids
281 else: 327 else:
282 unconditional_input_ids = negative_prompt 328 unconditional_input_ids = negative_prompt
283 unconditional_input_ids *= num_images_per_prompt 329 unconditional_input_ids *= num_images_per_prompt
@@ -286,12 +332,17 @@ class VlpnStableDiffusion(DiffusionPipeline):
286 text_inputs = unify_input_ids(self.tokenizer, text_input_ids) 332 text_inputs = unify_input_ids(self.tokenizer, text_input_ids)
287 text_input_ids = text_inputs.input_ids 333 text_input_ids = text_inputs.input_ids
288 334
289 if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 335 if (
336 hasattr(self.text_encoder.config, "use_attention_mask")
337 and self.text_encoder.config.use_attention_mask
338 ):
290 attention_mask = text_inputs.attention_mask.to(device) 339 attention_mask = text_inputs.attention_mask.to(device)
291 else: 340 else:
292 attention_mask = None 341 attention_mask = None
293 342
294 prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids.to(device), attention_mask) 343 prompt_embeds = get_extended_embeddings(
344 self.text_encoder, text_input_ids.to(device), attention_mask
345 )
295 prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 346 prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
296 347
297 return prompt_embeds 348 return prompt_embeds
@@ -301,25 +352,21 @@ class VlpnStableDiffusion(DiffusionPipeline):
301 init_timestep = min(int(num_inference_steps * strength), num_inference_steps) 352 init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
302 353
303 t_start = max(num_inference_steps - init_timestep, 0) 354 t_start = max(num_inference_steps - init_timestep, 0)
304 timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:] 355 timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
305 356
306 timesteps = timesteps.to(device) 357 timesteps = timesteps.to(device)
307 358
308 return timesteps, num_inference_steps - t_start 359 return timesteps, num_inference_steps - t_start
309 360
310 def prepare_brightness_offset(self, batch_size, height, width, dtype, device, generator=None): 361 def prepare_latents_from_image(
311 offset_image = perlin_noise( 362 self,
312 (batch_size, 1, width, height), 363 init_image,
313 res=1, 364 timestep,
314 generator=generator, 365 batch_size,
315 dtype=dtype, 366 dtype,
316 device=device 367 device,
317 ) 368 generator=None,
318 offset_latents = self.vae.encode(offset_image).latent_dist.sample(generator=generator) 369 ):
319 offset_latents = self.vae.config.scaling_factor * offset_latents
320 return offset_latents
321
322 def prepare_latents_from_image(self, init_image, timestep, batch_size, brightness_offset, dtype, device, generator=None):
323 init_image = init_image.to(device=device, dtype=dtype) 370 init_image = init_image.to(device=device, dtype=dtype)
324 latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) 371 latents = self.vae.encode(init_image).latent_dist.sample(generator=generator)
325 latents = self.vae.config.scaling_factor * latents 372 latents = self.vae.config.scaling_factor * latents
@@ -333,20 +380,32 @@ class VlpnStableDiffusion(DiffusionPipeline):
333 latents = torch.cat([latents] * batch_multiplier, dim=0) 380 latents = torch.cat([latents] * batch_multiplier, dim=0)
334 381
335 # add noise to latents using the timesteps 382 # add noise to latents using the timesteps
336 noise = torch.randn(latents.shape, generator=generator, device=device, dtype=dtype) 383 noise = torch.randn(
337 384 latents.shape, generator=generator, device=device, dtype=dtype
338 if brightness_offset != 0: 385 )
339 noise += brightness_offset * self.prepare_brightness_offset(
340 batch_size, init_image.shape[3], init_image.shape[2], dtype, device, generator
341 )
342 386
343 # get latents 387 # get latents
344 latents = self.scheduler.add_noise(latents, noise, timestep) 388 latents = self.scheduler.add_noise(latents, noise, timestep)
345 389
346 return latents 390 return latents
347 391
348 def prepare_latents(self, batch_size, num_channels_latents, height, width, brightness_offset, dtype, device, generator, latents=None): 392 def prepare_latents(
349 shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) 393 self,
394 batch_size,
395 num_channels_latents,
396 height,
397 width,
398 dtype,
399 device,
400 generator,
401 latents=None,
402 ):
403 shape = (
404 batch_size,
405 num_channels_latents,
406 height // self.vae_scale_factor,
407 width // self.vae_scale_factor,
408 )
350 if isinstance(generator, list) and len(generator) != batch_size: 409 if isinstance(generator, list) and len(generator) != batch_size:
351 raise ValueError( 410 raise ValueError(
352 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 411 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -354,15 +413,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
354 ) 413 )
355 414
356 if latents is None: 415 if latents is None:
357 latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 416 latents = randn_tensor(
417 shape, generator=generator, device=device, dtype=dtype
418 )
358 else: 419 else:
359 latents = latents.to(device) 420 latents = latents.to(device)
360 421
361 if brightness_offset != 0:
362 latents += brightness_offset * self.prepare_brightness_offset(
363 batch_size, height, width, dtype, device, generator
364 )
365
366 # scale the initial noise by the standard deviation required by the scheduler 422 # scale the initial noise by the standard deviation required by the scheduler
367 latents = latents * self.scheduler.init_noise_sigma 423 latents = latents * self.scheduler.init_noise_sigma
368 return latents 424 return latents
@@ -373,13 +429,17 @@ class VlpnStableDiffusion(DiffusionPipeline):
373 # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 429 # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
374 # and should be between [0, 1] 430 # and should be between [0, 1]
375 431
376 accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 432 accepts_eta = "eta" in set(
433 inspect.signature(self.scheduler.step).parameters.keys()
434 )
377 extra_step_kwargs = {} 435 extra_step_kwargs = {}
378 if accepts_eta: 436 if accepts_eta:
379 extra_step_kwargs["eta"] = eta 437 extra_step_kwargs["eta"] = eta
380 438
381 # check if the scheduler accepts generator 439 # check if the scheduler accepts generator
382 accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 440 accepts_generator = "generator" in set(
441 inspect.signature(self.scheduler.step).parameters.keys()
442 )
383 if accepts_generator: 443 if accepts_generator:
384 extra_step_kwargs["generator"] = generator 444 extra_step_kwargs["generator"] = generator
385 return extra_step_kwargs 445 return extra_step_kwargs
@@ -396,7 +456,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
396 def __call__( 456 def __call__(
397 self, 457 self,
398 prompt: Union[str, List[str], List[int], List[List[int]]], 458 prompt: Union[str, List[str], List[int], List[List[int]]],
399 negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None, 459 negative_prompt: Optional[
460 Union[str, List[str], List[int], List[List[int]]]
461 ] = None,
400 num_images_per_prompt: int = 1, 462 num_images_per_prompt: int = 1,
401 strength: float = 1.0, 463 strength: float = 1.0,
402 height: Optional[int] = None, 464 height: Optional[int] = None,
@@ -407,12 +469,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
407 eta: float = 0.0, 469 eta: float = 0.0,
408 generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 470 generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
409 image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, 471 image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
410 brightness_offset: Union[float, torch.FloatTensor] = 0,
411 output_type: str = "pil", 472 output_type: str = "pil",
412 return_dict: bool = True, 473 return_dict: bool = True,
413 callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 474 callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
414 callback_steps: int = 1, 475 callback_steps: int = 1,
415 cross_attention_kwargs: Optional[Dict[str, Any]] = None, 476 cross_attention_kwargs: Optional[Dict[str, Any]] = None,
477 guidance_rescale: float = 0.0,
416 ): 478 ):
417 r""" 479 r"""
418 Function invoked when calling the pipeline for generation. 480 Function invoked when calling the pipeline for generation.
@@ -472,7 +534,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
472 width = width or self.unet.config.sample_size * self.vae_scale_factor 534 width = width or self.unet.config.sample_size * self.vae_scale_factor
473 535
474 # 1. Check inputs. Raise error if not correct 536 # 1. Check inputs. Raise error if not correct
475 prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps) 537 prompt, negative_prompt = self.check_inputs(
538 prompt, negative_prompt, width, height, strength, callback_steps
539 )
476 540
477 # 2. Define call parameters 541 # 2. Define call parameters
478 batch_size = len(prompt) 542 batch_size = len(prompt)
@@ -488,7 +552,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
488 negative_prompt, 552 negative_prompt,
489 num_images_per_prompt, 553 num_images_per_prompt,
490 do_classifier_free_guidance, 554 do_classifier_free_guidance,
491 device 555 device,
492 ) 556 )
493 557
494 # 4. Prepare latent variables 558 # 4. Prepare latent variables
@@ -497,7 +561,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
497 561
498 # 5. Prepare timesteps 562 # 5. Prepare timesteps
499 self.scheduler.set_timesteps(num_inference_steps, device=device) 563 self.scheduler.set_timesteps(num_inference_steps, device=device)
500 timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) 564 timesteps, num_inference_steps = self.get_timesteps(
565 num_inference_steps, strength, device
566 )
501 latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) 567 latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
502 568
503 # 6. Prepare latent variables 569 # 6. Prepare latent variables
@@ -506,7 +572,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
506 image, 572 image,
507 latent_timestep, 573 latent_timestep,
508 batch_size * num_images_per_prompt, 574 batch_size * num_images_per_prompt,
509 brightness_offset,
510 prompt_embeds.dtype, 575 prompt_embeds.dtype,
511 device, 576 device,
512 generator, 577 generator,
@@ -517,7 +582,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
517 num_channels_latents, 582 num_channels_latents,
518 height, 583 height,
519 width, 584 width,
520 brightness_offset,
521 prompt_embeds.dtype, 585 prompt_embeds.dtype,
522 device, 586 device,
523 generator, 587 generator,
@@ -530,14 +594,20 @@ class VlpnStableDiffusion(DiffusionPipeline):
530 # 8. Denoising loo 594 # 8. Denoising loo
531 if do_self_attention_guidance: 595 if do_self_attention_guidance:
532 store_processor = CrossAttnStoreProcessor() 596 store_processor = CrossAttnStoreProcessor()
533 self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor 597 self.unet.mid_block.attentions[0].transformer_blocks[
598 0
599 ].attn1.processor = store_processor
534 600
535 num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 601 num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
536 with self.progress_bar(total=num_inference_steps) as progress_bar: 602 with self.progress_bar(total=num_inference_steps) as progress_bar:
537 for i, t in enumerate(timesteps): 603 for i, t in enumerate(timesteps):
538 # expand the latents if we are doing classifier free guidance 604 # expand the latents if we are doing classifier free guidance
539 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 605 latent_model_input = (
540 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 606 torch.cat([latents] * 2) if do_classifier_free_guidance else latents
607 )
608 latent_model_input = self.scheduler.scale_model_input(
609 latent_model_input, t
610 )
541 611
542 # predict the noise residual 612 # predict the noise residual
543 noise_pred = self.unet( 613 noise_pred = self.unet(
@@ -551,7 +621,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
551 # perform guidance 621 # perform guidance
552 if do_classifier_free_guidance: 622 if do_classifier_free_guidance:
553 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 623 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
554 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 624 noise_pred = noise_pred_uncond + guidance_scale * (
625 noise_pred_text - noise_pred_uncond
626 )
627 noise_pred = rescale_noise_cfg(
628 noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
629 )
555 630
556 if do_self_attention_guidance: 631 if do_self_attention_guidance:
557 # classifier-free guidance produces two chunks of attention map 632 # classifier-free guidance produces two chunks of attention map
@@ -561,15 +636,24 @@ class VlpnStableDiffusion(DiffusionPipeline):
561 # DDIM-like prediction of x0 636 # DDIM-like prediction of x0
562 pred_x0 = self.pred_x0(latents, noise_pred_uncond, t) 637 pred_x0 = self.pred_x0(latents, noise_pred_uncond, t)
563 # get the stored attention maps 638 # get the stored attention maps
564 uncond_attn, cond_attn = store_processor.attention_probs.chunk(2) 639 uncond_attn, cond_attn = store_processor.attention_probs.chunk(
640 2
641 )
565 # self-attention-based degrading of latents 642 # self-attention-based degrading of latents
566 degraded_latents = self.sag_masking( 643 degraded_latents = self.sag_masking(
567 pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t) 644 pred_x0,
645 uncond_attn,
646 t,
647 self.pred_epsilon(latents, noise_pred_uncond, t),
568 ) 648 )
569 uncond_emb, _ = prompt_embeds.chunk(2) 649 uncond_emb, _ = prompt_embeds.chunk(2)
570 # forward and give guidance 650 # forward and give guidance
571 degraded_pred = self.unet( 651 degraded_pred = self.unet(
572 degraded_latents, t, encoder_hidden_states=uncond_emb, return_dict=False)[0] 652 degraded_latents,
653 t,
654 encoder_hidden_states=uncond_emb,
655 return_dict=False,
656 )[0]
573 noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) 657 noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
574 else: 658 else:
575 # DDIM-like prediction of x0 659 # DDIM-like prediction of x0
@@ -578,18 +662,29 @@ class VlpnStableDiffusion(DiffusionPipeline):
578 cond_attn = store_processor.attention_probs 662 cond_attn = store_processor.attention_probs
579 # self-attention-based degrading of latents 663 # self-attention-based degrading of latents
580 degraded_latents = self.sag_masking( 664 degraded_latents = self.sag_masking(
581 pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t) 665 pred_x0,
666 cond_attn,
667 t,
668 self.pred_epsilon(latents, noise_pred, t),
582 ) 669 )
583 # forward and give guidance 670 # forward and give guidance
584 degraded_pred = self.unet( 671 degraded_pred = self.unet(
585 degraded_latents, t, encoder_hidden_states=prompt_embeds, return_dict=False)[0] 672 degraded_latents,
673 t,
674 encoder_hidden_states=prompt_embeds,
675 return_dict=False,
676 )[0]
586 noise_pred += sag_scale * (noise_pred - degraded_pred) 677 noise_pred += sag_scale * (noise_pred - degraded_pred)
587 678
588 # compute the previous noisy sample x_t -> x_t-1 679 # compute the previous noisy sample x_t -> x_t-1
589 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 680 latents = self.scheduler.step(
681 noise_pred, t, latents, **extra_step_kwargs, return_dict=False
682 )[0]
590 683
591 # call the callback, if provided 684 # call the callback, if provided
592 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 685 if i == len(timesteps) - 1 or (
686 (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
687 ):
593 progress_bar.update() 688 progress_bar.update()
594 if callback is not None and i % callback_steps == 0: 689 if callback is not None and i % callback_steps == 0:
595 callback(i, t, latents) 690 callback(i, t, latents)
@@ -615,7 +710,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
615 if not return_dict: 710 if not return_dict:
616 return (image, has_nsfw_concept) 711 return (image, has_nsfw_concept)
617 712
618 return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) 713 return StableDiffusionPipelineOutput(
714 images=image, nsfw_content_detected=has_nsfw_concept
715 )
619 716
620 # Self-Attention-Guided (SAG) Stable Diffusion 717 # Self-Attention-Guided (SAG) Stable Diffusion
621 718
@@ -632,16 +729,23 @@ class VlpnStableDiffusion(DiffusionPipeline):
632 attn_map = attn_map.reshape(b, h, hw1, hw2) 729 attn_map = attn_map.reshape(b, h, hw1, hw2)
633 attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0 730 attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
634 attn_mask = ( 731 attn_mask = (
635 attn_mask.reshape(b, map_size, map_size).unsqueeze(1).repeat(1, latent_channel, 1, 1).type(attn_map.dtype) 732 attn_mask.reshape(b, map_size, map_size)
733 .unsqueeze(1)
734 .repeat(1, latent_channel, 1, 1)
735 .type(attn_map.dtype)
636 ) 736 )
637 attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w)) 737 attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w))
638 738
639 # Blur according to the self-attention mask 739 # Blur according to the self-attention mask
640 degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0) 740 degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0)
641 degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask) 741 degraded_latents = degraded_latents * attn_mask + original_latents * (
742 1 - attn_mask
743 )
642 744
643 # Noise it again to match the noise level 745 # Noise it again to match the noise level
644 degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t) 746 degraded_latents = self.scheduler.add_noise(
747 degraded_latents, noise=eps, timesteps=t
748 )
645 749
646 return degraded_latents 750 return degraded_latents
647 751
@@ -652,13 +756,19 @@ class VlpnStableDiffusion(DiffusionPipeline):
652 756
653 beta_prod_t = 1 - alpha_prod_t 757 beta_prod_t = 1 - alpha_prod_t
654 if self.scheduler.config.prediction_type == "epsilon": 758 if self.scheduler.config.prediction_type == "epsilon":
655 pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) 759 pred_original_sample = (
760 sample - beta_prod_t ** (0.5) * model_output
761 ) / alpha_prod_t ** (0.5)
656 elif self.scheduler.config.prediction_type == "sample": 762 elif self.scheduler.config.prediction_type == "sample":
657 pred_original_sample = model_output 763 pred_original_sample = model_output
658 elif self.scheduler.config.prediction_type == "v_prediction": 764 elif self.scheduler.config.prediction_type == "v_prediction":
659 pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output 765 pred_original_sample = (alpha_prod_t**0.5) * sample - (
766 beta_prod_t**0.5
767 ) * model_output
660 # predict V 768 # predict V
661 model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample 769 model_output = (alpha_prod_t**0.5) * model_output + (
770 beta_prod_t**0.5
771 ) * sample
662 else: 772 else:
663 raise ValueError( 773 raise ValueError(
664 f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," 774 f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
@@ -674,9 +784,13 @@ class VlpnStableDiffusion(DiffusionPipeline):
674 if self.scheduler.config.prediction_type == "epsilon": 784 if self.scheduler.config.prediction_type == "epsilon":
675 pred_eps = model_output 785 pred_eps = model_output
676 elif self.scheduler.config.prediction_type == "sample": 786 elif self.scheduler.config.prediction_type == "sample":
677 pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5) 787 pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (
788 beta_prod_t**0.5
789 )
678 elif self.scheduler.config.prediction_type == "v_prediction": 790 elif self.scheduler.config.prediction_type == "v_prediction":
679 pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output 791 pred_eps = (beta_prod_t**0.5) * sample + (
792 alpha_prod_t**0.5
793 ) * model_output
680 else: 794 else:
681 raise ValueError( 795 raise ValueError(
682 f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," 796 f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"