From 2f3f3644f723f5c1500939c5dfe4cf4da81e4831 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 17 Mar 2023 15:57:46 +0100 Subject: Fixed snr weight calculation --- training/functional.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'training/functional.py') diff --git a/training/functional.py b/training/functional.py index 1baf9c6..27527ef 100644 --- a/training/functional.py +++ b/training/functional.py @@ -309,10 +309,13 @@ def loss_step( # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise + snr_weights = 1 elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(latents, noise, timesteps) - snr = target / (1 - target) + + alpha_t = noise_scheduler.alphas_cumprod[timesteps].float()[..., None, None, None].expand(bsz, 1, 1, 1) + snr = alpha_t / (1 - alpha_t) snr /= snr + 1 snr_weights = snr.fmin(torch.tensor([5], device=latents.device)) else: -- cgit v1.2.3-54-g00ecf