summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--training/functional.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/training/functional.py b/training/functional.py
index 1baf9c6..27527ef 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -309,10 +309,13 @@ def loss_step(
309 # Get the target for loss depending on the prediction type 309 # Get the target for loss depending on the prediction type
310 if noise_scheduler.config.prediction_type == "epsilon": 310 if noise_scheduler.config.prediction_type == "epsilon":
311 target = noise 311 target = noise
312
312 snr_weights = 1 313 snr_weights = 1
313 elif noise_scheduler.config.prediction_type == "v_prediction": 314 elif noise_scheduler.config.prediction_type == "v_prediction":
314 target = noise_scheduler.get_velocity(latents, noise, timesteps) 315 target = noise_scheduler.get_velocity(latents, noise, timesteps)
315 snr = target / (1 - target) 316
317 alpha_t = noise_scheduler.alphas_cumprod[timesteps].float()[..., None, None, None].expand(bsz, 1, 1, 1)
318 snr = alpha_t / (1 - alpha_t)
316 snr /= snr + 1 319 snr /= snr + 1
317 snr_weights = snr.fmin(torch.tensor([5], device=latents.device)) 320 snr_weights = snr.fmin(torch.tensor([5], device=latents.device))
318 else: 321 else: