diff options
Diffstat (limited to 'pipelines/stable_diffusion')
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index f80e951..78a34d5 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
| @@ -194,7 +194,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 194 | 194 | ||
| 195 | return prompt, negative_prompt | 195 | return prompt, negative_prompt |
| 196 | 196 | ||
| 197 | def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance): | 197 | def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance, device): |
| 198 | text_input_ids = self.prompt_processor.get_input_ids(prompt) | 198 | text_input_ids = self.prompt_processor.get_input_ids(prompt) |
| 199 | text_input_ids *= num_images_per_prompt | 199 | text_input_ids *= num_images_per_prompt |
| 200 | 200 | ||
| @@ -203,8 +203,15 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 203 | unconditional_input_ids *= num_images_per_prompt | 203 | unconditional_input_ids *= num_images_per_prompt |
| 204 | text_input_ids = unconditional_input_ids + text_input_ids | 204 | text_input_ids = unconditional_input_ids + text_input_ids |
| 205 | 205 | ||
| 206 | text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids) | 206 | text_inputs = self.prompt_processor.unify_input_ids(text_input_ids) |
| 207 | text_embeddings = self.prompt_processor.get_embeddings(text_input_ids) | 207 | text_input_ids = text_inputs.input_ids |
| 208 | |||
| 209 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: | ||
| 210 | attention_mask = text_inputs.attention_mask.to(device) | ||
| 211 | else: | ||
| 212 | attention_mask = None | ||
| 213 | |||
| 214 | text_embeddings = self.prompt_processor.get_embeddings(text_input_ids, attention_mask) | ||
| 208 | 215 | ||
| 209 | return text_embeddings | 216 | return text_embeddings |
| 210 | 217 | ||
| @@ -373,7 +380,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 373 | prompt, | 380 | prompt, |
| 374 | negative_prompt, | 381 | negative_prompt, |
| 375 | num_images_per_prompt, | 382 | num_images_per_prompt, |
| 376 | do_classifier_free_guidance | 383 | do_classifier_free_guidance, |
| 384 | device | ||
| 377 | ) | 385 | ) |
| 378 | 386 | ||
| 379 | # 4. Prepare timesteps | 387 | # 4. Prepare timesteps |
