From 3396ca881ed3f3521617cd9024eea56975191d32 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 5 Jan 2023 13:26:32 +0100 Subject: Update --- training/common.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'training') diff --git a/training/common.py b/training/common.py index 99a6e67..ab2741a 100644 --- a/training/common.py +++ b/training/common.py @@ -40,7 +40,10 @@ def run_model( noisy_latents = noisy_latents.to(dtype=unet.dtype) # Get the text embedding for conditioning - encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) + encoder_hidden_states = prompt_processor.get_embeddings( + batch["input_ids"], + batch["attention_mask"] + ) encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) # Predict the noise residual -- cgit v1.2.3-54-g00ecf