From 6c38d0088ece492696a7bc94a5cb43a48289452a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 14 Jan 2023 09:35:42 +0100 Subject: Fix --- training/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'training') diff --git a/training/common.py b/training/common.py index f5ab326..8083137 100644 --- a/training/common.py +++ b/training/common.py @@ -184,7 +184,7 @@ def loss_step( else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - if batch["with_prior"]: + if batch["with_prior"].all(): # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) -- cgit v1.2.3-54-g00ecf