From 8abbd633d8ee7500058b2f1f69a6d6611b5a4450 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 17 Mar 2023 15:18:20 +0100 Subject: Test: https://arxiv.org/pdf/2303.09556.pdf --- training/functional.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 4565612..2d6553a 100644 --- a/training/functional.py +++ b/training/functional.py @@ -309,8 +309,12 @@ def loss_step( # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise + snr_weights = 1 elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(latents, noise, timesteps) + snr = target / (1 - target) + snr /= snr + 1 + snr_weights = torch.minimum(snr, torch.tensor([5], device=latents.device)) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") @@ -320,16 +324,17 @@ def loss_step( target, target_prior = torch.chunk(target, 2, dim=0) # Compute instance loss - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") # Compute prior loss - prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none") # Add the prior loss to the instance loss. loss = loss + prior_loss_weight * prior_loss else: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = (snr_weights * loss).mean() acc = (model_pred == target).float().mean() return loss, acc, bsz -- cgit v1.2.3-54-g00ecf