diff options
Diffstat (limited to 'pipelines')
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 3027421..dab7878 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
| @@ -1,6 +1,6 @@ | |||
| 1 | import inspect | 1 | import inspect |
| 2 | import warnings | 2 | import warnings |
| 3 | from typing import List, Optional, Union, Callable | 3 | from typing import List, Dict, Any, Optional, Union, Callable |
| 4 | 4 | ||
| 5 | import numpy as np | 5 | import numpy as np |
| 6 | import torch | 6 | import torch |
| @@ -337,6 +337,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 337 | return_dict: bool = True, | 337 | return_dict: bool = True, |
| 338 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | 338 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
| 339 | callback_steps: int = 1, | 339 | callback_steps: int = 1, |
| 340 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, | ||
| 340 | ): | 341 | ): |
| 341 | r""" | 342 | r""" |
| 342 | Function invoked when calling the pipeline for generation. | 343 | Function invoked when calling the pipeline for generation. |
| @@ -379,6 +380,10 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 379 | return_dict (`bool`, *optional*, defaults to `True`): | 380 | return_dict (`bool`, *optional*, defaults to `True`): |
| 380 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a | 381 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a |
| 381 | plain tuple. | 382 | plain tuple. |
| 383 | cross_attention_kwargs (`dict`, *optional*): | ||
| 384 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under | ||
| 385 | `self.processor` in | ||
| 386 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). | ||
| 382 | 387 | ||
| 383 | Returns: | 388 | Returns: |
| 384 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: | 389 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: |
| @@ -450,7 +455,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 450 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | 455 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
| 451 | 456 | ||
| 452 | # predict the noise residual | 457 | # predict the noise residual |
| 453 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample | 458 | noise_pred = self.unet( |
| 459 | latent_model_input, | ||
| 460 | t, | ||
| 461 | encoder_hidden_states=text_embeddings, | ||
| 462 | cross_attention_kwargs=cross_attention_kwargs, | ||
| 463 | ).sample | ||
| 454 | 464 | ||
| 455 | # perform guidance | 465 | # perform guidance |
| 456 | if do_classifier_free_guidance: | 466 | if do_classifier_free_guidance: |
