summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-17 15:38:57 +0100
committerVolpeon <git@volpeon.ink>2023-03-17 15:38:57 +0100
commit5d850b4893fdb3710a32158879c123b3c411d7e7 (patch)
tree8e7c5f412b2ed1ed54f1043ce2ff14e946d11caa /training/functional.py
parentTest: https://arxiv.org/pdf/2303.09556.pdf (diff)
downloadtextual-inversion-diff-5d850b4893fdb3710a32158879c123b3c411d7e7.tar.gz
textual-inversion-diff-5d850b4893fdb3710a32158879c123b3c411d7e7.tar.bz2
textual-inversion-diff-5d850b4893fdb3710a32158879c123b3c411d7e7.zip
Fix loss=nan
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/training/functional.py b/training/functional.py
index 2d6553a..1baf9c6 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -314,7 +314,7 @@ def loss_step(
314 target = noise_scheduler.get_velocity(latents, noise, timesteps) 314 target = noise_scheduler.get_velocity(latents, noise, timesteps)
315 snr = target / (1 - target) 315 snr = target / (1 - target)
316 snr /= snr + 1 316 snr /= snr + 1
317 snr_weights = torch.minimum(snr, torch.tensor([5], device=latents.device)) 317 snr_weights = snr.fmin(torch.tensor([5], device=latents.device))
318 else: 318 else:
319 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 319 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
320 320
@@ -334,7 +334,7 @@ def loss_step(
334 else: 334 else:
335 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") 335 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
336 336
337 loss = (snr_weights * loss).mean() 337 loss = (snr_weights * loss).mean([1, 2, 3]).mean()
338 acc = (model_pred == target).float().mean() 338 acc = (model_pred == target).float().mean()
339 339
340 return loss, acc, bsz 340 return loss, acc, bsz