diff options
author | Volpeon <git@volpeon.ink> | 2022-12-10 13:12:37 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-10 13:12:37 +0100 |
commit | 358874cd2c49cb55676af86d2950b86d9ccb023a (patch) | |
tree | 786239d45944bab1f2af8e24165fc5d5054617f3 /textual_inversion.py | |
parent | Various updated; shuffle prompt content during training (diff) | |
download | textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.tar.gz textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.tar.bz2 textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.zip |
Support attention_mask of text encoder
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 1a5a8d0..da7c747 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -641,13 +641,14 @@ def main(): | |||
641 | pixel_values = torch.stack(pixel_values) | 641 | pixel_values = torch.stack(pixel_values) |
642 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) | 642 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) |
643 | 643 | ||
644 | input_ids = prompt_processor.unify_input_ids(input_ids) | 644 | inputs = prompt_processor.unify_input_ids(input_ids) |
645 | 645 | ||
646 | batch = { | 646 | batch = { |
647 | "prompts": prompts, | 647 | "prompts": prompts, |
648 | "nprompts": nprompts, | 648 | "nprompts": nprompts, |
649 | "input_ids": input_ids, | 649 | "input_ids": inputs.input_ids, |
650 | "pixel_values": pixel_values, | 650 | "pixel_values": pixel_values, |
651 | "attention_mask": inputs.attention_mask, | ||
651 | } | 652 | } |
652 | return batch | 653 | return batch |
653 | 654 | ||
@@ -849,7 +850,7 @@ def main(): | |||
849 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 850 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
850 | 851 | ||
851 | # Get the text embedding for conditioning | 852 | # Get the text embedding for conditioning |
852 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) | 853 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) |
853 | encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) | 854 | encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) |
854 | 855 | ||
855 | # Predict the noise residual | 856 | # Predict the noise residual |
@@ -948,7 +949,7 @@ def main(): | |||
948 | 949 | ||
949 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 950 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
950 | 951 | ||
951 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) | 952 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) |
952 | encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) | 953 | encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) |
953 | 954 | ||
954 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 955 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |