diff options
author | Volpeon <git@volpeon.ink> | 2023-03-17 15:38:57 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-17 15:38:57 +0100 |
commit | 5d850b4893fdb3710a32158879c123b3c411d7e7 (patch) | |
tree | 8e7c5f412b2ed1ed54f1043ce2ff14e946d11caa | |
parent | Test: https://arxiv.org/pdf/2303.09556.pdf (diff) | |
download | textual-inversion-diff-5d850b4893fdb3710a32158879c123b3c411d7e7.tar.gz textual-inversion-diff-5d850b4893fdb3710a32158879c123b3c411d7e7.tar.bz2 textual-inversion-diff-5d850b4893fdb3710a32158879c123b3c411d7e7.zip |
Fix loss=nan
-rw-r--r-- | training/functional.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/training/functional.py b/training/functional.py index 2d6553a..1baf9c6 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -314,7 +314,7 @@ def loss_step( | |||
314 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | 314 | target = noise_scheduler.get_velocity(latents, noise, timesteps) |
315 | snr = target / (1 - target) | 315 | snr = target / (1 - target) |
316 | snr /= snr + 1 | 316 | snr /= snr + 1 |
317 | snr_weights = torch.minimum(snr, torch.tensor([5], device=latents.device)) | 317 | snr_weights = snr.fmin(torch.tensor([5], device=latents.device)) |
318 | else: | 318 | else: |
319 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 319 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |
320 | 320 | ||
@@ -334,7 +334,7 @@ def loss_step( | |||
334 | else: | 334 | else: |
335 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") | 335 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") |
336 | 336 | ||
337 | loss = (snr_weights * loss).mean() | 337 | loss = (snr_weights * loss).mean([1, 2, 3]).mean() |
338 | acc = (model_pred == target).float().mean() | 338 | acc = (model_pred == target).float().mean() |
339 | 339 | ||
340 | return loss, acc, bsz | 340 | return loss, acc, bsz |