summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-10 13:12:37 +0100
committerVolpeon <git@volpeon.ink>2022-12-10 13:12:37 +0100
commit358874cd2c49cb55676af86d2950b86d9ccb023a (patch)
tree786239d45944bab1f2af8e24165fc5d5054617f3 /dreambooth.py
parentVarious updates; shuffle prompt content during training (diff)
downloadtextual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.tar.gz
textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.tar.bz2
textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.zip
Support attention_mask of text encoder
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py9
1 file changed, 5 insertions, 4 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 0044c1e..1ef5156 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -723,13 +723,14 @@ def main():
723 pixel_values = torch.stack(pixel_values) 723 pixel_values = torch.stack(pixel_values)
724 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) 724 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
725 725
726 input_ids = prompt_processor.unify_input_ids(input_ids) 726 inputs = prompt_processor.unify_input_ids(input_ids)
727 727
728 batch = { 728 batch = {
729 "prompts": prompts, 729 "prompts": prompts,
730 "nprompts": nprompts, 730 "nprompts": nprompts,
731 "input_ids": input_ids, 731 "input_ids": inputs.input_ids,
732 "pixel_values": pixel_values, 732 "pixel_values": pixel_values,
733 "attention_mask": inputs.attention_mask,
733 } 734 }
734 return batch 735 return batch
735 736
@@ -935,7 +936,7 @@ def main():
935 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 936 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
936 937
937 # Get the text embedding for conditioning 938 # Get the text embedding for conditioning
938 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) 939 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
939 940
940 # Predict the noise residual 941 # Predict the noise residual
941 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 942 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -1047,7 +1048,7 @@ def main():
1047 1048
1048 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 1049 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
1049 1050
1050 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) 1051 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
1051 1052
1052 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 1053 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
1053 1054