From 358874cd2c49cb55676af86d2950b86d9ccb023a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 10 Dec 2022 13:12:37 +0100
Subject: Support attention_mask of text encoder

---
 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index f80e951..78a34d5 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -194,7 +194,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return prompt, negative_prompt
 
-    def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance):
+    def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance, device):
         text_input_ids = self.prompt_processor.get_input_ids(prompt)
         text_input_ids *= num_images_per_prompt
 
@@ -203,8 +203,15 @@ class VlpnStableDiffusion(DiffusionPipeline):
             unconditional_input_ids *= num_images_per_prompt
             text_input_ids = unconditional_input_ids + text_input_ids
 
-        text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids)
-        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids)
+        text_inputs = self.prompt_processor.unify_input_ids(text_input_ids)
+        text_input_ids = text_inputs.input_ids
+
+        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+            attention_mask = text_inputs.attention_mask.to(device)
+        else:
+            attention_mask = None
+
+        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids, attention_mask)
 
         return text_embeddings
 
@@ -373,7 +380,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
             prompt,
             negative_prompt,
             num_images_per_prompt,
-            do_classifier_free_guidance
+            do_classifier_free_guidance,
+            device
         )
 
         # 4. Prepare timesteps
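For reference, a minimal sketch of how the patched portion of encode_prompt reads after this change, assembled from the hunks above. Only the lines visible in the diff are shown; the earlier prompt/negative-prompt handling is unchanged by this patch and is elided here.

    def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance, device):
        # ... input id handling as before (unchanged by this patch, omitted) ...

        text_inputs = self.prompt_processor.unify_input_ids(text_input_ids)
        text_input_ids = text_inputs.input_ids

        # Pass an attention mask only if the text encoder opts in via its config,
        # and move it to the requested device along with the other inputs.
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids, attention_mask)

        return text_embeddings

The call site in the pipeline (last hunk) now forwards device as the extra argument, which is where the attention mask is placed before encoding.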