diff options
author | Volpeon <git@volpeon.ink> | 2022-12-10 13:12:37 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-10 13:12:37 +0100 |
commit | 358874cd2c49cb55676af86d2950b86d9ccb023a (patch) | |
tree | 786239d45944bab1f2af8e24165fc5d5054617f3 /pipelines | |
parent | Various updated; shuffle prompt content during training (diff) | |
download | textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.tar.gz textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.tar.bz2 textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.zip |
Support attention_mask of text encoder
Diffstat (limited to 'pipelines')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 16 |
1 file changed, 12 insertions(+), 4 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index f80e951..78a34d5 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -194,7 +194,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
194 | 194 | ||
195 | return prompt, negative_prompt | 195 | return prompt, negative_prompt |
196 | 196 | ||
197 | def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance): | 197 | def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance, device): |
198 | text_input_ids = self.prompt_processor.get_input_ids(prompt) | 198 | text_input_ids = self.prompt_processor.get_input_ids(prompt) |
199 | text_input_ids *= num_images_per_prompt | 199 | text_input_ids *= num_images_per_prompt |
200 | 200 | ||
@@ -203,8 +203,15 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
203 | unconditional_input_ids *= num_images_per_prompt | 203 | unconditional_input_ids *= num_images_per_prompt |
204 | text_input_ids = unconditional_input_ids + text_input_ids | 204 | text_input_ids = unconditional_input_ids + text_input_ids |
205 | 205 | ||
206 | text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids) | 206 | text_inputs = self.prompt_processor.unify_input_ids(text_input_ids) |
207 | text_embeddings = self.prompt_processor.get_embeddings(text_input_ids) | 207 | text_input_ids = text_inputs.input_ids |
208 | |||
209 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: | ||
210 | attention_mask = text_inputs.attention_mask.to(device) | ||
211 | else: | ||
212 | attention_mask = None | ||
213 | |||
214 | text_embeddings = self.prompt_processor.get_embeddings(text_input_ids, attention_mask) | ||
208 | 215 | ||
209 | return text_embeddings | 216 | return text_embeddings |
210 | 217 | ||
@@ -373,7 +380,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
373 | prompt, | 380 | prompt, |
374 | negative_prompt, | 381 | negative_prompt, |
375 | num_images_per_prompt, | 382 | num_images_per_prompt, |
376 | do_classifier_free_guidance | 383 | do_classifier_free_guidance, |
384 | device | ||
377 | ) | 385 | ) |
378 | 386 | ||
379 | # 4. Prepare timesteps | 387 | # 4. Prepare timesteps |