From 5d850b4893fdb3710a32158879c123b3c411d7e7 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Fri, 17 Mar 2023 15:38:57 +0100
Subject: Fix loss=nan

---
 training/functional.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/training/functional.py b/training/functional.py
index 2d6553a..1baf9c6 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -314,7 +314,7 @@ def loss_step(
         target = noise_scheduler.get_velocity(latents, noise, timesteps)
         snr = target / (1 - target)
         snr /= snr + 1
-        snr_weights = torch.minimum(snr, torch.tensor([5], device=latents.device))
+        snr_weights = snr.fmin(torch.tensor([5], device=latents.device))
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
@@ -334,7 +334,7 @@ def loss_step(
     else:
         loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
 
-    loss = (snr_weights * loss).mean()
+    loss = (snr_weights * loss).mean([1, 2, 3]).mean()
     acc = (model_pred == target).float().mean()
 
     return loss, acc, bsz
-- 
cgit v1.2.3-70-g09d2