summary refs log tree commit diff stats
path: root/pipelines/stable_diffusion
diff options
context:
space:
mode:
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r-- pipelines/stable_diffusion/vlpn_stable_diffusion.py 16
1 files changed, 12 insertions, 4 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index f80e951..78a34d5 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -194,7 +194,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
194 194
195 return prompt, negative_prompt 195 return prompt, negative_prompt
196 196
197 def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance): 197 def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance, device):
198 text_input_ids = self.prompt_processor.get_input_ids(prompt) 198 text_input_ids = self.prompt_processor.get_input_ids(prompt)
199 text_input_ids *= num_images_per_prompt 199 text_input_ids *= num_images_per_prompt
200 200
@@ -203,8 +203,15 @@ class VlpnStableDiffusion(DiffusionPipeline):
203 unconditional_input_ids *= num_images_per_prompt 203 unconditional_input_ids *= num_images_per_prompt
204 text_input_ids = unconditional_input_ids + text_input_ids 204 text_input_ids = unconditional_input_ids + text_input_ids
205 205
206 text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids) 206 text_inputs = self.prompt_processor.unify_input_ids(text_input_ids)
207 text_embeddings = self.prompt_processor.get_embeddings(text_input_ids) 207 text_input_ids = text_inputs.input_ids
208
209 if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
210 attention_mask = text_inputs.attention_mask.to(device)
211 else:
212 attention_mask = None
213
214 text_embeddings = self.prompt_processor.get_embeddings(text_input_ids, attention_mask)
208 215
209 return text_embeddings 216 return text_embeddings
210 217
@@ -373,7 +380,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
373 prompt, 380 prompt,
374 negative_prompt, 381 negative_prompt,
375 num_images_per_prompt, 382 num_images_per_prompt,
376 do_classifier_free_guidance 383 do_classifier_free_guidance,
384 device
377 ) 385 )
378 386
379 # 4. Prepare timesteps 387 # 4. Prepare timesteps