summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/common.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/training/common.py b/training/common.py
index f5ab326..8083137 100644
--- a/training/common.py
+++ b/training/common.py
@@ -184,7 +184,7 @@ def loss_step(
184 else: 184 else:
185 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 185 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
186 186
187 if batch["with_prior"]: 187 if batch["with_prior"].all():
188 # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 188 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
189 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 189 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
190 target, target_prior = torch.chunk(target, 2, dim=0) 190 target, target_prior = torch.chunk(target, 2, dim=0)