diff options
author | Volpeon <git@volpeon.ink> | 2022-12-10 13:12:37 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-10 13:12:37 +0100 |
commit | 358874cd2c49cb55676af86d2950b86d9ccb023a (patch) | |
tree | 786239d45944bab1f2af8e24165fc5d5054617f3 /dreambooth.py | |
parent | Various updates; shuffle prompt content during training (diff) |
download | textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.tar.gz textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.tar.bz2 textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.zip |
Support attention_mask of text encoder
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 9 |
1 file changed, 5 insertions, 4 deletions
diff --git a/dreambooth.py b/dreambooth.py index 0044c1e..1ef5156 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -723,13 +723,14 @@ def main(): | |||
723 | pixel_values = torch.stack(pixel_values) | 723 | pixel_values = torch.stack(pixel_values) |
724 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) | 724 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) |
725 | 725 | ||
726 | input_ids = prompt_processor.unify_input_ids(input_ids) | 726 | inputs = prompt_processor.unify_input_ids(input_ids) |
727 | 727 | ||
728 | batch = { | 728 | batch = { |
729 | "prompts": prompts, | 729 | "prompts": prompts, |
730 | "nprompts": nprompts, | 730 | "nprompts": nprompts, |
731 | "input_ids": input_ids, | 731 | "input_ids": inputs.input_ids, |
732 | "pixel_values": pixel_values, | 732 | "pixel_values": pixel_values, |
733 | "attention_mask": inputs.attention_mask, | ||
733 | } | 734 | } |
734 | return batch | 735 | return batch |
735 | 736 | ||
@@ -935,7 +936,7 @@ def main(): | |||
935 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 936 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
936 | 937 | ||
937 | # Get the text embedding for conditioning | 938 | # Get the text embedding for conditioning |
938 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) | 939 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) |
939 | 940 | ||
940 | # Predict the noise residual | 941 | # Predict the noise residual |
941 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 942 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
@@ -1047,7 +1048,7 @@ def main(): | |||
1047 | 1048 | ||
1048 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 1049 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
1049 | 1050 | ||
1050 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) | 1051 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) |
1051 | 1052 | ||
1052 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 1053 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
1053 | 1054 | ||