summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/common.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/training/common.py b/training/common.py
index 99a6e67..ab2741a 100644
--- a/training/common.py
+++ b/training/common.py
@@ -40,7 +40,10 @@ def run_model(
40 noisy_latents = noisy_latents.to(dtype=unet.dtype) 40 noisy_latents = noisy_latents.to(dtype=unet.dtype)
41 41
42 # Get the text embedding for conditioning 42 # Get the text embedding for conditioning
43 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) 43 encoder_hidden_states = prompt_processor.get_embeddings(
44 batch["input_ids"],
45 batch["attention_mask"]
46 )
44 encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) 47 encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)
45 48
46 # Predict the noise residual 49 # Predict the noise residual