diff options
Diffstat (limited to 'pipelines')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 3027421..dab7878 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -1,6 +1,6 @@ | |||
1 | import inspect | 1 | import inspect |
2 | import warnings | 2 | import warnings |
3 | from typing import List, Optional, Union, Callable | 3 | from typing import List, Dict, Any, Optional, Union, Callable |
4 | 4 | ||
5 | import numpy as np | 5 | import numpy as np |
6 | import torch | 6 | import torch |
@@ -337,6 +337,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
337 | return_dict: bool = True, | 337 | return_dict: bool = True, |
338 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | 338 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
339 | callback_steps: int = 1, | 339 | callback_steps: int = 1, |
340 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, | ||
340 | ): | 341 | ): |
341 | r""" | 342 | r""" |
342 | Function invoked when calling the pipeline for generation. | 343 | Function invoked when calling the pipeline for generation. |
@@ -379,6 +380,10 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
379 | return_dict (`bool`, *optional*, defaults to `True`): | 380 | return_dict (`bool`, *optional*, defaults to `True`): |
380 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a | 381 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a |
381 | plain tuple. | 382 | plain tuple. |
383 | cross_attention_kwargs (`dict`, *optional*): | ||
384 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under | ||
385 | `self.processor` in | ||
386 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). | ||
382 | 387 | ||
383 | Returns: | 388 | Returns: |
384 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: | 389 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: |
@@ -450,7 +455,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
450 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | 455 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
451 | 456 | ||
452 | # predict the noise residual | 457 | # predict the noise residual |
453 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample | 458 | noise_pred = self.unet( |
459 | latent_model_input, | ||
460 | t, | ||
461 | encoder_hidden_states=text_embeddings, | ||
462 | cross_attention_kwargs=cross_attention_kwargs, | ||
463 | ).sample | ||
454 | 464 | ||
455 | # perform guidance | 465 | # perform guidance |
456 | if do_classifier_free_guidance: | 466 | if do_classifier_free_guidance: |