summary refs log tree commit diff stats
path: root/pipelines
diff options
context:
space:
mode:
author    Volpeon <git@volpeon.ink>  2022-12-10 13:12:37 +0100
committer Volpeon <git@volpeon.ink>  2022-12-10 13:12:37 +0100
commit    358874cd2c49cb55676af86d2950b86d9ccb023a (patch)
tree      786239d45944bab1f2af8e24165fc5d5054617f3 /pipelines
parent    Various updates; shuffle prompt content during training (diff)
download  textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.tar.gz
          textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.tar.bz2
          textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.zip
Support attention_mask of text encoder
Diffstat (limited to 'pipelines')
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py | 16
1 files changed, 12 insertions, 4 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index f80e951..78a34d5 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -194,7 +194,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return prompt, negative_prompt
 
-    def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance):
+    def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance, device):
         text_input_ids = self.prompt_processor.get_input_ids(prompt)
         text_input_ids *= num_images_per_prompt
 
@@ -203,8 +203,15 @@ class VlpnStableDiffusion(DiffusionPipeline):
         unconditional_input_ids *= num_images_per_prompt
         text_input_ids = unconditional_input_ids + text_input_ids
 
-        text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids)
-        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids)
+        text_inputs = self.prompt_processor.unify_input_ids(text_input_ids)
+        text_input_ids = text_inputs.input_ids
+
+        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+            attention_mask = text_inputs.attention_mask.to(device)
+        else:
+            attention_mask = None
+
+        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids, attention_mask)
 
         return text_embeddings
210 217
@@ -373,7 +380,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
             prompt,
             negative_prompt,
             num_images_per_prompt,
-            do_classifier_free_guidance
+            do_classifier_free_guidance,
+            device
         )
 
         # 4. Prepare timesteps