diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/common.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/training/common.py b/training/common.py index 99a6e67..ab2741a 100644 --- a/training/common.py +++ b/training/common.py | |||
@@ -40,7 +40,10 @@ def run_model( | |||
40 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | 40 | noisy_latents = noisy_latents.to(dtype=unet.dtype) |
41 | 41 | ||
42 | # Get the text embedding for conditioning | 42 | # Get the text embedding for conditioning |
43 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) | 43 | encoder_hidden_states = prompt_processor.get_embeddings( |
44 | batch["input_ids"], | ||
45 | batch["attention_mask"] | ||
46 | ) | ||
44 | encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) | 47 | encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) |
45 | 48 | ||
46 | # Predict the noise residual | 49 | # Predict the noise residual |