From 358874cd2c49cb55676af86d2950b86d9ccb023a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 10 Dec 2022 13:12:37 +0100
Subject: Support attention_mask of text encoder

---
 textual_inversion.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

(limited to 'textual_inversion.py')

diff --git a/textual_inversion.py b/textual_inversion.py
index 1a5a8d0..da7c747 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -641,13 +641,14 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
-        input_ids = prompt_processor.unify_input_ids(input_ids)
+        inputs = prompt_processor.unify_input_ids(input_ids)
 
         batch = {
             "prompts": prompts,
             "nprompts": nprompts,
-            "input_ids": input_ids,
+            "input_ids": inputs.input_ids,
             "pixel_values": pixel_values,
+            "attention_mask": inputs.attention_mask,
         }
 
         return batch
@@ -849,7 +850,7 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
                 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
                 # Predict the noise residual
@@ -948,7 +949,7 @@ def main():
 
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
                 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
--
cgit v1.2.3-54-g00ecf
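
For context, a minimal sketch of the PromptProcessor interface this patch relies on. The class itself is not part of the diff; the names unify_input_ids and get_embeddings come from the patch, while the Hugging Face tokenizer.pad call and the CLIP-style text encoder forward signature below are assumptions about how it is likely implemented:

    # Hypothetical sketch, assuming a Hugging Face tokenizer and a
    # CLIP-style text encoder; the real implementation lives outside
    # this patch.
    import torch

    class PromptProcessor:
        def __init__(self, tokenizer, text_encoder):
            self.tokenizer = tokenizer
            self.text_encoder = text_encoder

        def unify_input_ids(self, input_ids):
            # tokenizer.pad returns a BatchEncoding exposing both
            # .input_ids and .attention_mask, padded to the longest
            # prompt in the batch -- which is why the collate_fn above
            # can now pass the mask through the batch dict.
            return self.tokenizer.pad(
                {"input_ids": input_ids},
                padding=True,
                return_tensors="pt",
            )

        def get_embeddings(self, input_ids, attention_mask=None):
            # Forwarding the mask lets the text encoder's attention
            # layers ignore padding tokens instead of attending to them.
            return self.text_encoder(
                input_ids, attention_mask=attention_mask
            )[0]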