From 8c9dd1a230daf8a662447465c32dcae46ecbbe5f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 18 Mar 2023 09:42:28 +0100 Subject: Better SNR weighting --- training/functional.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/training/functional.py b/training/functional.py index 27527ef..b9574ec 100644 --- a/training/functional.py +++ b/training/functional.py @@ -314,10 +314,10 @@ def loss_step( elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(latents, noise, timesteps) - alpha_t = noise_scheduler.alphas_cumprod[timesteps].float()[..., None, None, None].expand(bsz, 1, 1, 1) + alpha_t = noise_scheduler.alphas_cumprod[timesteps].float() snr = alpha_t / (1 - alpha_t) - snr /= snr + 1 - snr_weights = snr.fmin(torch.tensor([5], device=latents.device)) + snr_weights = snr.fmin(torch.tensor([5], device=latents.device)) / (snr + 1) + snr_weights = snr_weights[..., None, None, None] else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") -- cgit v1.2.3-70-g09d2