summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/training/functional.py b/training/functional.py
index b9574ec..15b95ba 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -314,9 +314,12 @@ def loss_step(
314 elif noise_scheduler.config.prediction_type == "v_prediction": 314 elif noise_scheduler.config.prediction_type == "v_prediction":
315 target = noise_scheduler.get_velocity(latents, noise, timesteps) 315 target = noise_scheduler.get_velocity(latents, noise, timesteps)
316 316
317 p2_gamma = 1
318 p2_k = 1
319
317 alpha_t = noise_scheduler.alphas_cumprod[timesteps].float() 320 alpha_t = noise_scheduler.alphas_cumprod[timesteps].float()
318 snr = alpha_t / (1 - alpha_t) 321 snr = 1.0 / (1 - alpha_t) - 1
319 snr_weights = snr.fmin(torch.tensor([5], device=latents.device)) / (snr + 1) 322 snr_weights = 1 / (p2_k + snr) ** p2_gamma
320 snr_weights = snr_weights[..., None, None, None] 323 snr_weights = snr_weights[..., None, None, None]
321 else: 324 else:
322 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 325 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")