diff options
-rw-r--r-- | training/functional.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/training/functional.py b/training/functional.py index 1baf9c6..27527ef 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -309,10 +309,13 @@ def loss_step( | |||
309 | # Get the target for loss depending on the prediction type | 309 | # Get the target for loss depending on the prediction type |
310 | if noise_scheduler.config.prediction_type == "epsilon": | 310 | if noise_scheduler.config.prediction_type == "epsilon": |
311 | target = noise | 311 | target = noise |
312 | |||
312 | snr_weights = 1 | 313 | snr_weights = 1 |
313 | elif noise_scheduler.config.prediction_type == "v_prediction": | 314 | elif noise_scheduler.config.prediction_type == "v_prediction": |
314 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | 315 | target = noise_scheduler.get_velocity(latents, noise, timesteps) |
315 | snr = target / (1 - target) | 316 | |
317 | alpha_t = noise_scheduler.alphas_cumprod[timesteps].float()[..., None, None, None].expand(bsz, 1, 1, 1) | ||
318 | snr = alpha_t / (1 - alpha_t) | ||
316 | snr /= snr + 1 | 319 | snr /= snr + 1 |
317 | snr_weights = snr.fmin(torch.tensor([5], device=latents.device)) | 320 | snr_weights = snr.fmin(torch.tensor([5], device=latents.device)) |
318 | else: | 321 | else: |