From 873bff68e5f37c85753df4240f3b2b2b88fc99a7 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 18 Mar 2023 22:15:49 +0100 Subject: New loss weighting from arxiv.org:2204.00227 --- training/functional.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'training/functional.py') diff --git a/training/functional.py b/training/functional.py index b9574ec..15b95ba 100644 --- a/training/functional.py +++ b/training/functional.py @@ -314,9 +314,12 @@ def loss_step( elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(latents, noise, timesteps) + p2_gamma = 1 + p2_k = 1 + alpha_t = noise_scheduler.alphas_cumprod[timesteps].float() - snr = alpha_t / (1 - alpha_t) - snr_weights = snr.fmin(torch.tensor([5], device=latents.device)) / (snr + 1) + snr = 1.0 / (1 - alpha_t) - 1 + snr_weights = 1 / (p2_k + snr) ** p2_gamma snr_weights = snr_weights[..., None, None, None] else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") -- cgit v1.2.3-54-g00ecf