From 7ccd4614a56cfd6ecacba85605f338593f1059f0 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Tue, 7 Feb 2023 20:44:43 +0100
Subject: Add LoRA

Pass cross_attention_kwargs through the pipeline __call__ to the UNet so
that attention processors (such as LoRA layers) can receive per-call
arguments like a blending scale.

---
 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 3027421..dab7878 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -1,6 +1,6 @@
 import inspect
 import warnings
-from typing import List, Optional, Union, Callable
+from typing import List, Dict, Any, Optional, Union, Callable
 
 import numpy as np
 import torch
@@ -337,6 +337,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -379,6 +380,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that, if specified, is passed along to the `AttnProcessor` as defined under
+                `self.processor` in
+                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
 
         Returns:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -450,7 +455,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=text_embeddings,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                ).sample
 
                 # perform guidance
                 if do_classifier_free_guidance:
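
For context, one way this hook is exercised is to install LoRA attention processors on the UNet so they can consume the forwarded kwargs. The sketch below follows the diffusers LoRA recipe that was current around this patch; LoRACrossAttnProcessor lives in the cross_attention module referenced in the docstring above, and the name-to-hidden-size mapping is assumed from that recipe rather than taken from this repository.

from diffusers.models.cross_attention import LoRACrossAttnProcessor

def install_lora_processors(unet, rank=4):
    """Replace every attention processor on `unet` with a LoRA processor."""
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        # attn1 is self-attention (no encoder hidden states); attn2 is
        # cross-attention and uses the UNet's cross-attention dimension.
        cross_attention_dim = (
            None if name.endswith("attn1.processor")
            else unet.config.cross_attention_dim
        )
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        lora_attn_procs[name] = LoRACrossAttnProcessor(
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
            rank=rank,
        )
    unet.set_attn_processor(lora_attn_procs)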
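
With processors installed, whatever dictionary is passed as cross_attention_kwargs is forwarded verbatim through the UNet to each processor's __call__; diffusers' LoRA processors read a `scale` entry from it. A hypothetical invocation follows; the checkpoint name and prompt are illustrative, and loading this custom pipeline with from_pretrained assumes its components match the checkpoint's.

import torch
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

pipe = VlpnStableDiffusion.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative checkpoint
    torch_dtype=torch.float16,
).to("cuda")
install_lora_processors(pipe.unet)  # sketch from above

# scale=0.0 reduces to the base model; scale=1.0 applies the LoRA
# weights at full strength.
output = pipe(
    "a watercolor fox in a snowy forest",
    cross_attention_kwargs={"scale": 0.5},
)
image = output.images[0]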