author     Volpeon <git@volpeon.ink>  2022-12-10 13:12:37 +0100
committer  Volpeon <git@volpeon.ink>  2022-12-10 13:12:37 +0100
commit     358874cd2c49cb55676af86d2950b86d9ccb023a (patch)
tree       786239d45944bab1f2af8e24165fc5d5054617f3  /textual_inversion.py
parent     Various updated; shuffle prompt content during training (diff)
download   textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.tar.gz
           textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.tar.bz2
           textual-inversion-diff-358874cd2c49cb55676af86d2950b86d9ccb023a.zip
Support attention_mask of text encoder
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
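
Note: the new `inputs` object used in the collate function below has to expose both `input_ids` and `attention_mask`. As a point of reference, here is a minimal sketch of what `prompt_processor.unify_input_ids` could look like if it delegates padding to the tokenizer; this is hypothetical and the actual implementation in this repo may differ:

    from transformers import CLIPTokenizer

    # Hypothetical sketch only: pad a list of token-id sequences via the
    # tokenizer so the result is a BatchEncoding exposing .input_ids and
    # .attention_mask, matching what the collate code in this diff accesses.
    def unify_input_ids(tokenizer: CLIPTokenizer, input_ids):
        return tokenizer.pad(
            {"input_ids": input_ids},
            padding=True,
            return_tensors="pt",
        )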
diff --git a/textual_inversion.py b/textual_inversion.py
index 1a5a8d0..da7c747 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -641,13 +641,14 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
-        input_ids = prompt_processor.unify_input_ids(input_ids)
+        inputs = prompt_processor.unify_input_ids(input_ids)
 
         batch = {
             "prompts": prompts,
             "nprompts": nprompts,
-            "input_ids": input_ids,
+            "input_ids": inputs.input_ids,
             "pixel_values": pixel_values,
+            "attention_mask": inputs.attention_mask,
         }
         return batch
 
@@ -849,7 +850,7 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
                 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
                 # Predict the noise residual
@@ -948,7 +949,7 @@ def main():
 
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
                 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
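
Note: on the training side, `get_embeddings` now receives the attention mask as a second argument. A plausible sketch, assuming the processor simply forwards the mask to the CLIP text encoder; this is hypothetical and the real PromptProcessor may differ:

    from typing import Optional

    import torch
    from transformers import CLIPTextModel

    # Hypothetical sketch only: forward the mask so padded positions are
    # excluded from self-attention instead of influencing the embedding.
    def get_embeddings(text_encoder: CLIPTextModel,
                       input_ids: torch.Tensor,
                       attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return text_encoder(input_ids, attention_mask=attention_mask)[0]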