 dreambooth.py                                       |  9 +++++----
 models/clip/prompt.py                               | 10 +++++++---
 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 16 ++++++++++++----
 textual_inversion.py                                |  9 +++++----
 4 files changed, 29 insertions(+), 15 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 0044c1e..1ef5156 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -723,13 +723,14 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
-        input_ids = prompt_processor.unify_input_ids(input_ids)
+        inputs = prompt_processor.unify_input_ids(input_ids)
 
         batch = {
             "prompts": prompts,
             "nprompts": nprompts,
-            "input_ids": input_ids,
+            "input_ids": inputs.input_ids,
             "pixel_values": pixel_values,
+            "attention_mask": inputs.attention_mask,
         }
         return batch
 
@@ -935,7 +936,7 @@ def main():
             noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
             # Get the text embedding for conditioning
-            encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+            encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
 
             # Predict the noise residual
             model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -1047,7 +1048,7 @@ def main():
 
             noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-            encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+            encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
 
             model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
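Note: the collate change above only works because unify_input_ids now returns the tokenizer's full BatchEncoding (see models/clip/prompt.py below) rather than bare ids, so both .input_ids and .attention_mask are available on the result. A minimal sketch of that contract, assuming the Hugging Face tokenizer.pad semantics the arguments in that hunk point to (the checkpoint name is a stand-in, not taken from this repo):

    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    # Two prompts of different lengths; pad() aligns them to a multiple of
    # model_max_length (77) and reports real tokens (1) versus padding (0).
    encoded = [tokenizer(p).input_ids for p in ["a photo of a dog", "a dog"]]
    inputs = tokenizer.pad(
        {"input_ids": encoded},
        padding=True,
        pad_to_multiple_of=tokenizer.model_max_length,
        return_tensors="pt",
    )
    print(inputs.input_ids.shape)       # torch.Size([2, 77])
    print(inputs.attention_mask.shape)  # torch.Size([2, 77])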
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index 6b6b7e9..9b427a0 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -22,11 +22,15 @@ class PromptProcessor():
             padding=True,
             pad_to_multiple_of=self.tokenizer.model_max_length,
             return_tensors="pt"
-        ).input_ids
+        )
 
-    def get_embeddings(self, input_ids: torch.IntTensor):
+    def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None):
        prompts = input_ids.shape[0]
+
         input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
-        text_embeddings = self.text_encoder(input_ids)[0]
+        if attention_mask is not None:
+            attention_mask = attention_mask.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
+
+        text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0]
         text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2]))
         return text_embeddings
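The reshape added for the mask mirrors the existing one for the ids: both arrive as (batch, n * model_max_length) tensors and are split into 77-token windows before the encoder call, so window i of the ids stays paired with window i of the mask. A self-contained sketch of the patched flow (device handling omitted; the checkpoint name is a placeholder; CLIPTextModel.forward does accept attention_mask, which lets the encoder ignore pad positions instead of attending to pad tokens):

    from transformers import CLIPTextModel, CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    def get_embeddings(input_ids, attention_mask=None):
        prompts = input_ids.shape[0]
        # Split (batch, n * 77) into (batch * n, 77) windows for the encoder.
        input_ids = input_ids.reshape((-1, tokenizer.model_max_length))
        if attention_mask is not None:
            # The mask must follow the ids through the same reshape, or
            # windows and masks would pair up wrong.
            attention_mask = attention_mask.reshape((-1, tokenizer.model_max_length))
        emb = text_encoder(input_ids, attention_mask=attention_mask)[0]
        # Stitch the windows back together per prompt: (batch, n * 77, hidden).
        return emb.reshape((prompts, -1, emb.shape[2]))

Called with the inputs from the collate sketch above, this returns a (2, 77, 768) tensor for the two example prompts.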
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index f80e951..78a34d5 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -194,7 +194,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return prompt, negative_prompt
 
-    def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance):
+    def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance, device):
         text_input_ids = self.prompt_processor.get_input_ids(prompt)
         text_input_ids *= num_images_per_prompt
 
@@ -203,8 +203,15 @@ class VlpnStableDiffusion(DiffusionPipeline):
             unconditional_input_ids *= num_images_per_prompt
             text_input_ids = unconditional_input_ids + text_input_ids
 
-        text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids)
-        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids)
+        text_inputs = self.prompt_processor.unify_input_ids(text_input_ids)
+        text_input_ids = text_inputs.input_ids
+
+        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+            attention_mask = text_inputs.attention_mask.to(device)
+        else:
+            attention_mask = None
+
+        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids, attention_mask)
 
         return text_embeddings
 
@@ -373,7 +380,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
             prompt,
             negative_prompt,
             num_images_per_prompt,
-            do_classifier_free_guidance
+            do_classifier_free_guidance,
+            device
         )
 
         # 4. Prepare timesteps
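The hasattr gate keeps existing checkpoints untouched: a stock CLIPTextConfig carries no use_attention_mask attribute, so the branch falls through and attention_mask stays None, exactly the pre-patch behavior; only encoders whose config explicitly opts in receive the mask. The new device parameter exists so the mask can be moved onto the right device before get_embeddings runs. A quick way to confirm the default (getattr with a fallback is the compact equivalent of the hasattr-and-check above):

    from transformers import CLIPTextConfig

    config = CLIPTextConfig()
    # No `use_attention_mask` on a stock CLIP text config, so the pipeline
    # keeps passing attention_mask=None for standard Stable Diffusion
    # checkpoints.
    print(getattr(config, "use_attention_mask", False))  # False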
diff --git a/textual_inversion.py b/textual_inversion.py
index 1a5a8d0..da7c747 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -641,13 +641,14 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
-        input_ids = prompt_processor.unify_input_ids(input_ids)
+        inputs = prompt_processor.unify_input_ids(input_ids)
 
         batch = {
             "prompts": prompts,
             "nprompts": nprompts,
-            "input_ids": input_ids,
+            "input_ids": inputs.input_ids,
             "pixel_values": pixel_values,
+            "attention_mask": inputs.attention_mask,
         }
         return batch
 
@@ -849,7 +850,7 @@ def main():
             noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
             # Get the text embedding for conditioning
-            encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+            encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
             encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
             # Predict the noise residual
@@ -948,7 +949,7 @@ def main():
 
             noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-            encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+            encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
             encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
             model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
