author     Volpeon <git@volpeon.ink>   2022-12-10 13:12:37 +0100
committer  Volpeon <git@volpeon.ink>   2022-12-10 13:12:37 +0100
commit     358874cd2c49cb55676af86d2950b86d9ccb023a
tree       786239d45944bab1f2af8e24165fc5d5054617f3
parent     Various updated; shuffle prompt content during training
download   textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.tar.gz
           textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.tar.bz2
           textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.zip
Support attention_mask of text encoder
 dreambooth.py                                        |  9
 models/clip/prompt.py                                | 10
 pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 16
 textual_inversion.py                                 |  9
 4 files changed, 29 insertions(+), 15 deletions(-)
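
Note on the change: the tokenizer output was previously reduced to bare input_ids, so the attention_mask produced alongside them was discarded; this commit keeps the mask and hands it to the text encoder so padded token positions can be excluded from attention. A minimal, self-contained illustration of where that mask comes from (the model name and prompts below are placeholders, not taken from this repo):

from transformers import CLIPTokenizer

# Illustrative tokenizer; the training scripts load their own instance.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# With padding enabled, the tokenizer returns a BatchEncoding that carries the
# input_ids and a matching attention_mask (1 = real token, 0 = padding).
inputs = tokenizer(
    ["a photo of a cat", "a watercolor painting of a fox in the snow"],
    padding=True,
    return_tensors="pt",
)
print(inputs.input_ids.shape)       # e.g. torch.Size([2, 13])
print(inputs.attention_mask.shape)  # same shape as input_ids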
diff --git a/dreambooth.py b/dreambooth.py
index 0044c1e..1ef5156 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -723,13 +723,14 @@ def main():
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)

-    input_ids = prompt_processor.unify_input_ids(input_ids)
+    inputs = prompt_processor.unify_input_ids(input_ids)

     batch = {
         "prompts": prompts,
         "nprompts": nprompts,
-        "input_ids": input_ids,
+        "input_ids": inputs.input_ids,
         "pixel_values": pixel_values,
+        "attention_mask": inputs.attention_mask,
     }
     return batch

@@ -935,7 +936,7 @@ def main():
     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

     # Get the text embedding for conditioning
-    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])

     # Predict the noise residual
     model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -1047,7 +1048,7 @@ def main():

     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

-    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])

     model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

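
The batch now carries the mask next to the padded ids. A condensed, hypothetical version of that collate function (the example keys and the prompt_processor parameter are illustrative; the real function is a closure inside main() and also carries the prompts/nprompts entries shown in the diff):

import torch

def collate_fn(examples, prompt_processor, weight_dtype=torch.float32):
    # Stack images and pad all prompt ids in one go; unify_input_ids now
    # returns the full BatchEncoding, so the mask rides along with the ids.
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)

    inputs = prompt_processor.unify_input_ids([example["input_ids"] for example in examples])

    return {
        "pixel_values": pixel_values,
        "input_ids": inputs.input_ids,            # padded token ids
        "attention_mask": inputs.attention_mask,  # 0 where padding was added
    }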
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index 6b6b7e9..9b427a0 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -22,11 +22,15 @@ class PromptProcessor():
            padding=True,
            pad_to_multiple_of=self.tokenizer.model_max_length,
            return_tensors="pt"
-        ).input_ids
+        )

-    def get_embeddings(self, input_ids: torch.IntTensor):
+    def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None):
         prompts = input_ids.shape[0]
+
         input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
-        text_embeddings = self.text_encoder(input_ids)[0]
+        if attention_mask is not None:
+            attention_mask = attention_mask.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
+
+        text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0]
         text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2]))
         return text_embeddings
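
Context for the reshapes in get_embeddings: because unify_input_ids pads to a multiple of tokenizer.model_max_length, a long prompt is encoded as several chunks of model_max_length tokens and the chunk embeddings are concatenated back per prompt, so the mask has to be reshaped the same way as the ids. A standalone sketch of that masked, chunked encoding, assuming the stock CLIP tokenizer and text encoder from transformers (the prompts are made up and deliberately exceed 77 tokens, so the tokenizer will warn about the length):

import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Illustrative models; the repo wires in its own tokenizer and text encoder.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

max_len = tokenizer.model_max_length  # 77 for CLIP

# Pad both prompts to a shared multiple of max_len; the mask marks real tokens.
inputs = tokenizer(
    [
        "a detailed watercolor painting of a fox in the snow " * 10,
        "a highly detailed portrait of a cat " * 15,
    ],
    padding=True,
    pad_to_multiple_of=max_len,
    return_tensors="pt",
)

prompts = inputs.input_ids.shape[0]
input_ids = inputs.input_ids.reshape((-1, max_len))
attention_mask = inputs.attention_mask.reshape((-1, max_len))

# Encode each max_len chunk with its mask, then stitch the chunks back per prompt.
with torch.no_grad():
    chunk_embeddings = text_encoder(input_ids, attention_mask=attention_mask)[0]
text_embeddings = chunk_embeddings.reshape((prompts, -1, chunk_embeddings.shape[2]))
print(text_embeddings.shape)  # expected (2, 154, 768): two 77-token chunks per prompt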
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index f80e951..78a34d5 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -194,7 +194,7 @@ class VlpnStableDiffusion(DiffusionPipeline):

         return prompt, negative_prompt

-    def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance):
+    def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance, device):
         text_input_ids = self.prompt_processor.get_input_ids(prompt)
         text_input_ids *= num_images_per_prompt

@@ -203,8 +203,15 @@ class VlpnStableDiffusion(DiffusionPipeline):
            unconditional_input_ids *= num_images_per_prompt
            text_input_ids = unconditional_input_ids + text_input_ids

-        text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids)
-        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids)
+        text_inputs = self.prompt_processor.unify_input_ids(text_input_ids)
+        text_input_ids = text_inputs.input_ids
+
+        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+            attention_mask = text_inputs.attention_mask.to(device)
+        else:
+            attention_mask = None
+
+        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids, attention_mask)

         return text_embeddings

@@ -373,7 +380,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
            prompt,
            negative_prompt,
            num_images_per_prompt,
-            do_classifier_free_guidance
+            do_classifier_free_guidance,
+            device
        )

        # 4. Prepare timesteps
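
The hasattr guard keeps the pipeline compatible with stock CLIP text encoders, whose config does not define use_attention_mask, while letting encoders that opt in receive the mask; it mirrors the guard diffusers' own StableDiffusionPipeline applies before passing a mask. The same check, pulled out into a hypothetical helper (the function name is not from the repo):

def resolve_attention_mask(text_encoder, text_inputs, device):
    # Equivalent to the hasattr + attribute check above: hand back the mask
    # only if the text encoder's config explicitly opts in, otherwise keep
    # the previous behaviour of passing no mask at all.
    if getattr(text_encoder.config, "use_attention_mask", False):
        return text_inputs.attention_mask.to(device)
    return None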
diff --git a/textual_inversion.py b/textual_inversion.py
index 1a5a8d0..da7c747 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -641,13 +641,14 @@ def main():
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)

-    input_ids = prompt_processor.unify_input_ids(input_ids)
+    inputs = prompt_processor.unify_input_ids(input_ids)

     batch = {
         "prompts": prompts,
         "nprompts": nprompts,
-        "input_ids": input_ids,
+        "input_ids": inputs.input_ids,
         "pixel_values": pixel_values,
+        "attention_mask": inputs.attention_mask,
     }
     return batch

@@ -849,7 +850,7 @@ def main():
     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

     # Get the text embedding for conditioning
-    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
     encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)

     # Predict the noise residual
@@ -948,7 +949,7 @@ def main():

     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

-    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
     encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)

     model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample