From 358874cd2c49cb55676af86d2950b86d9ccb023a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 10 Dec 2022 13:12:37 +0100 Subject: Support attention_mask of text encoder --- dreambooth.py | 9 +++++---- models/clip/prompt.py | 10 +++++++--- pipelines/stable_diffusion/vlpn_stable_diffusion.py | 16 ++++++++++++---- textual_inversion.py | 9 +++++---- 4 files changed, 29 insertions(+), 15 deletions(-) diff --git a/dreambooth.py b/dreambooth.py index 0044c1e..1ef5156 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -723,13 +723,14 @@ def main(): pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) - input_ids = prompt_processor.unify_input_ids(input_ids) + inputs = prompt_processor.unify_input_ids(input_ids) batch = { "prompts": prompts, "nprompts": nprompts, - "input_ids": input_ids, + "input_ids": inputs.input_ids, "pixel_values": pixel_values, + "attention_mask": inputs.attention_mask, } return batch @@ -935,7 +936,7 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) + encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) # Predict the noise residual model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample @@ -1047,7 +1048,7 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) + encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample diff --git a/models/clip/prompt.py b/models/clip/prompt.py index 6b6b7e9..9b427a0 100644 --- a/models/clip/prompt.py +++ b/models/clip/prompt.py @@ -22,11 +22,15 @@ class PromptProcessor(): padding=True, pad_to_multiple_of=self.tokenizer.model_max_length, return_tensors="pt" - ).input_ids + ) - def get_embeddings(self, input_ids: torch.IntTensor): + def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None): prompts = input_ids.shape[0] + input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) - text_embeddings = self.text_encoder(input_ids)[0] + if attention_mask is not None: + attention_mask = attention_mask.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) + + text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0] text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2])) return text_embeddings diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index f80e951..78a34d5 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -194,7 +194,7 @@ class VlpnStableDiffusion(DiffusionPipeline): return prompt, negative_prompt - def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance): + def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance, device): text_input_ids = self.prompt_processor.get_input_ids(prompt) text_input_ids *= num_images_per_prompt @@ -203,8 +203,15 @@ class VlpnStableDiffusion(DiffusionPipeline): unconditional_input_ids *= num_images_per_prompt text_input_ids = unconditional_input_ids + text_input_ids - text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids) - text_embeddings = self.prompt_processor.get_embeddings(text_input_ids) + text_inputs = self.prompt_processor.unify_input_ids(text_input_ids) + text_input_ids = text_inputs.input_ids + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.prompt_processor.get_embeddings(text_input_ids, attention_mask) return text_embeddings @@ -373,7 +380,8 @@ class VlpnStableDiffusion(DiffusionPipeline): prompt, negative_prompt, num_images_per_prompt, - do_classifier_free_guidance + do_classifier_free_guidance, + device ) # 4. Prepare timesteps diff --git a/textual_inversion.py b/textual_inversion.py index 1a5a8d0..da7c747 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -641,13 +641,14 @@ def main(): pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) - input_ids = prompt_processor.unify_input_ids(input_ids) + inputs = prompt_processor.unify_input_ids(input_ids) batch = { "prompts": prompts, "nprompts": nprompts, - "input_ids": input_ids, + "input_ids": inputs.input_ids, "pixel_values": pixel_values, + "attention_mask": inputs.attention_mask, } return batch @@ -849,7 +850,7 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) + encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) # Predict the noise residual @@ -948,7 +949,7 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) + encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample -- cgit v1.2.3-70-g09d2