From 358874cd2c49cb55676af86d2950b86d9ccb023a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 10 Dec 2022 13:12:37 +0100 Subject: Support attention_mask of text encoder --- dreambooth.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'dreambooth.py') diff --git a/dreambooth.py b/dreambooth.py index 0044c1e..1ef5156 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -723,13 +723,14 @@ def main(): pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) - input_ids = prompt_processor.unify_input_ids(input_ids) + inputs = prompt_processor.unify_input_ids(input_ids) batch = { "prompts": prompts, "nprompts": nprompts, - "input_ids": input_ids, + "input_ids": inputs.input_ids, "pixel_values": pixel_values, + "attention_mask": inputs.attention_mask, } return batch @@ -935,7 +936,7 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) + encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) # Predict the noise residual model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample @@ -1047,7 +1048,7 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) + encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample -- cgit v1.2.3-54-g00ecf