diff options
Diffstat (limited to 'textual_inversion.py')
| -rw-r--r-- | textual_inversion.py | 9 |
1 file changed, 5 insertions, 4 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 1a5a8d0..da7c747 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -641,13 +641,14 @@ def main(): | |||
| 641 | pixel_values = torch.stack(pixel_values) | 641 | pixel_values = torch.stack(pixel_values) |
| 642 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) | 642 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) |
| 643 | 643 | ||
| 644 | input_ids = prompt_processor.unify_input_ids(input_ids) | 644 | inputs = prompt_processor.unify_input_ids(input_ids) |
| 645 | 645 | ||
| 646 | batch = { | 646 | batch = { |
| 647 | "prompts": prompts, | 647 | "prompts": prompts, |
| 648 | "nprompts": nprompts, | 648 | "nprompts": nprompts, |
| 649 | "input_ids": input_ids, | 649 | "input_ids": inputs.input_ids, |
| 650 | "pixel_values": pixel_values, | 650 | "pixel_values": pixel_values, |
| 651 | "attention_mask": inputs.attention_mask, | ||
| 651 | } | 652 | } |
| 652 | return batch | 653 | return batch |
| 653 | 654 | ||
| @@ -849,7 +850,7 @@ def main(): | |||
| 849 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 850 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
| 850 | 851 | ||
| 851 | # Get the text embedding for conditioning | 852 | # Get the text embedding for conditioning |
| 852 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) | 853 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) |
| 853 | encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) | 854 | encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) |
| 854 | 855 | ||
| 855 | # Predict the noise residual | 856 | # Predict the noise residual |
| @@ -948,7 +949,7 @@ def main(): | |||
| 948 | 949 | ||
| 949 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 950 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
| 950 | 951 | ||
| 951 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) | 952 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) |
| 952 | encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) | 953 | encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) |
| 953 | 954 | ||
| 954 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 955 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
