 dreambooth.py                                       |  9 +++++----
 models/clip/prompt.py                               | 10 +++++++---
 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 16 ++++++++++++----
 textual_inversion.py                                |  9 +++++----
 4 files changed, 29 insertions(+), 15 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 0044c1e..1ef5156 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -723,13 +723,14 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
-        input_ids = prompt_processor.unify_input_ids(input_ids)
+        inputs = prompt_processor.unify_input_ids(input_ids)
 
         batch = {
             "prompts": prompts,
             "nprompts": nprompts,
-            "input_ids": input_ids,
+            "input_ids": inputs.input_ids,
             "pixel_values": pixel_values,
+            "attention_mask": inputs.attention_mask,
         }
         return batch
 
@@ -935,7 +936,7 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
 
                 # Predict the noise residual
                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -1047,7 +1048,7 @@ def main():
 
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
 
                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
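
Why the collate change matters: once prompts of different lengths are padded
into one batch, the text encoder needs a record of which positions are real
tokens, which is exactly what the attention mask carries. A minimal standalone
illustration (not code from this repo), assuming the transformers
CLIPTokenizer and a public checkpoint:

    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    inputs = tokenizer(
        ["a photo of a dog", "a photo of a very fluffy dog sitting on a sofa"],
        padding=True,
        return_tensors="pt",
    )
    print(inputs.input_ids.shape)    # both prompts padded to one common length
    print(inputs.attention_mask[0])  # trailing zeros mark the padding positions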
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index 6b6b7e9..9b427a0 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -22,11 +22,15 @@ class PromptProcessor():
             padding=True,
             pad_to_multiple_of=self.tokenizer.model_max_length,
             return_tensors="pt"
-        ).input_ids
+        )
 
-    def get_embeddings(self, input_ids: torch.IntTensor):
+    def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None):
         prompts = input_ids.shape[0]
+
         input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
-        text_embeddings = self.text_encoder(input_ids)[0]
+        if attention_mask is not None:
+            attention_mask = attention_mask.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
+
+        text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0]
         text_embeddings = text_embeddings.reshape((prompts, -1, text_embeddings.shape[2]))
         return text_embeddings
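
The reshape logic in get_embeddings handles prompts padded to a multiple of
model_max_length: each prompt is split into fixed-size windows, every window
is encoded alongside its matching slice of the mask, and the windows are
re-joined per prompt. A hedged sketch with stand-in tensors and a fake encoder
output (the real call is self.text_encoder(input_ids, attention_mask=...)[0]):

    import torch

    model_max_length, hidden = 77, 768  # assumed CLIP dimensions
    prompts, windows = 2, 3

    input_ids = torch.zeros(prompts, windows * model_max_length, dtype=torch.long)
    attention_mask = torch.ones_like(input_ids)

    flat_ids = input_ids.reshape(-1, model_max_length)        # (6, 77)
    flat_mask = attention_mask.reshape(-1, model_max_length)  # (6, 77)

    # Stand-in for the encoder output, shaped (batch, seq, hidden).
    hidden_states = torch.randn(flat_ids.shape[0], model_max_length, hidden)

    rejoined = hidden_states.reshape(prompts, -1, hidden_states.shape[2])
    print(rejoined.shape)  # torch.Size([2, 231, 768]): windows re-joined per prompt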
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index f80e951..78a34d5 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -194,7 +194,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return prompt, negative_prompt
 
-    def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance):
+    def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance, device):
         text_input_ids = self.prompt_processor.get_input_ids(prompt)
         text_input_ids *= num_images_per_prompt
 
@@ -203,8 +203,15 @@ class VlpnStableDiffusion(DiffusionPipeline):
             unconditional_input_ids *= num_images_per_prompt
             text_input_ids = unconditional_input_ids + text_input_ids
 
-        text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids)
-        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids)
+        text_inputs = self.prompt_processor.unify_input_ids(text_input_ids)
+        text_input_ids = text_inputs.input_ids
+
+        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+            attention_mask = text_inputs.attention_mask.to(device)
+        else:
+            attention_mask = None
+
+        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids, attention_mask)
 
         return text_embeddings
 
@@ -373,7 +380,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
             prompt,
             negative_prompt,
             num_images_per_prompt,
-            do_classifier_free_guidance
+            do_classifier_free_guidance,
+            device
         )
 
         # 4. Prepare timesteps
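
The hasattr guard in encode_prompt exists because only some text-encoder
configs declare use_attention_mask; for the rest, passing None preserves the
old call path. A hedged sketch of the same check as a helper
(pick_attention_mask and the SimpleNamespace stand-ins are hypothetical, not
part of this repo):

    from types import SimpleNamespace

    def pick_attention_mask(config, text_inputs, device):
        # Equivalent to: hasattr(config, "use_attention_mask") and config.use_attention_mask
        if getattr(config, "use_attention_mask", False):
            return text_inputs.attention_mask.to(device)
        return None

    config = SimpleNamespace(use_attention_mask=False)  # the common CLIP case
    inputs = SimpleNamespace(attention_mask=None)
    assert pick_attention_mask(config, inputs, "cpu") is None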
diff --git a/textual_inversion.py b/textual_inversion.py
index 1a5a8d0..da7c747 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -641,13 +641,14 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
-        input_ids = prompt_processor.unify_input_ids(input_ids)
+        inputs = prompt_processor.unify_input_ids(input_ids)
 
         batch = {
             "prompts": prompts,
             "nprompts": nprompts,
-            "input_ids": input_ids,
+            "input_ids": inputs.input_ids,
             "pixel_values": pixel_values,
+            "attention_mask": inputs.attention_mask,
         }
         return batch
 
@@ -849,7 +850,7 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
                 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
                 # Predict the noise residual
@@ -948,7 +949,7 @@ def main():
 
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
                 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
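
textual_inversion.py mirrors the dreambooth.py change, with an extra cast of
the embeddings to the training dtype afterwards. A hedged sketch of the call
shape with a mocked processor (MockProcessor and the tensor sizes are
stand-ins, not code from this repo):

    import torch

    weight_dtype = torch.float16

    class MockProcessor:
        def get_embeddings(self, input_ids, attention_mask=None):
            # Real version: self.text_encoder(input_ids, attention_mask=...)[0]
            return torch.randn(input_ids.shape[0], input_ids.shape[1], 768)

    batch = {
        "input_ids": torch.zeros(4, 77, dtype=torch.long),
        "attention_mask": torch.ones(4, 77, dtype=torch.long),
    }

    encoder_hidden_states = MockProcessor().get_embeddings(
        batch["input_ids"], batch["attention_mask"]
    )
    encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
    print(encoder_hidden_states.dtype)  # torch.float16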