From 5d850b4893fdb3710a32158879c123b3c411d7e7 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 17 Mar 2023 15:38:57 +0100 Subject: Fix loss=nan --- training/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 2d6553a..1baf9c6 100644 --- a/training/functional.py +++ b/training/functional.py @@ -314,7 +314,7 @@ def loss_step( target = noise_scheduler.get_velocity(latents, noise, timesteps) snr = target / (1 - target) snr /= snr + 1 - snr_weights = torch.minimum(snr, torch.tensor([5], device=latents.device)) + snr_weights = snr.fmin(torch.tensor([5], device=latents.device)) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") @@ -334,7 +334,7 @@ def loss_step( else: loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") - loss = (snr_weights * loss).mean() + loss = (snr_weights * loss).mean([1, 2, 3]).mean() acc = (model_pred == target).float().mean() return loss, acc, bsz -- cgit v1.2.3-70-g09d2