summary refs log tree commit diff stats
path: root/pipelines/stable_diffusion
diff options
context:
space:
mode:
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r-- pipelines/stable_diffusion/vlpn_stable_diffusion.py 14
1 files changed, 12 insertions, 2 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 3027421..dab7878 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -1,6 +1,6 @@
1 import inspect 1 import inspect
2 import warnings 2 import warnings
3 from typing import List, Optional, Union, Callable 3 from typing import List, Dict, Any, Optional, Union, Callable
4 4
5 import numpy as np 5 import numpy as np
6 import torch 6 import torch
@@ -337,6 +337,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
337 return_dict: bool = True, 337 return_dict: bool = True,
338 callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 338 callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
339 callback_steps: int = 1, 339 callback_steps: int = 1,
340 cross_attention_kwargs: Optional[Dict[str, Any]] = None,
340 ): 341 ):
341 r""" 342 r"""
342 Function invoked when calling the pipeline for generation. 343 Function invoked when calling the pipeline for generation.
@@ -379,6 +380,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
379 return_dict (`bool`, *optional*, defaults to `True`): 380 return_dict (`bool`, *optional*, defaults to `True`):
380 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 381 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
381 plain tuple. 382 plain tuple.
383 cross_attention_kwargs (`dict`, *optional*):
384 A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
385 `self.processor` in
386 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
382 387
383 Returns: 388 Returns:
384 [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 389 [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -450,7 +455,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
450 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 455 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
451 456
452 # predict the noise residual 457 # predict the noise residual
453 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample 458 noise_pred = self.unet(
459 latent_model_input,
460 t,
461 encoder_hidden_states=text_embeddings,
462 cross_attention_kwargs=cross_attention_kwargs,
463 ).sample
454 464
455 # perform guidance 465 # perform guidance
456 if do_classifier_free_guidance: 466 if do_classifier_free_guidance: