From 7ccd4614a56cfd6ecacba85605f338593f1059f0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 7 Feb 2023 20:44:43 +0100 Subject: Add Lora --- pipelines/stable_diffusion/vlpn_stable_diffusion.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) (limited to 'pipelines/stable_diffusion') diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 3027421..dab7878 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -1,6 +1,6 @@ import inspect import warnings -from typing import List, Optional, Union, Callable +from typing import List, Dict, Any, Optional, Union, Callable import numpy as np import torch @@ -337,6 +337,7 @@ class VlpnStableDiffusion(DiffusionPipeline): return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -379,6 +380,10 @@ class VlpnStableDiffusion(DiffusionPipeline): return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: @@ -450,7 +455,12 @@ class VlpnStableDiffusion(DiffusionPipeline): latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + cross_attention_kwargs=cross_attention_kwargs, + ).sample # perform guidance if do_classifier_free_guidance: -- cgit v1.2.3-70-g09d2