summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py
index 4565612..2d6553a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -309,8 +309,12 @@ def loss_step(
309 # Get the target for loss depending on the prediction type 309 # Get the target for loss depending on the prediction type
310 if noise_scheduler.config.prediction_type == "epsilon": 310 if noise_scheduler.config.prediction_type == "epsilon":
311 target = noise 311 target = noise
312 snr_weights = 1
312 elif noise_scheduler.config.prediction_type == "v_prediction": 313 elif noise_scheduler.config.prediction_type == "v_prediction":
313 target = noise_scheduler.get_velocity(latents, noise, timesteps) 314 target = noise_scheduler.get_velocity(latents, noise, timesteps)
315 snr = target / (1 - target)
316 snr /= snr + 1
317 snr_weights = torch.minimum(snr, torch.tensor([5], device=latents.device))
314 else: 318 else:
315 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 319 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
316 320
@@ -320,16 +324,17 @@ def loss_step(
320 target, target_prior = torch.chunk(target, 2, dim=0) 324 target, target_prior = torch.chunk(target, 2, dim=0)
321 325
322 # Compute instance loss 326 # Compute instance loss
323 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 327 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
324 328
325 # Compute prior loss 329 # Compute prior loss
326 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") 330 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
327 331
328 # Add the prior loss to the instance loss. 332 # Add the prior loss to the instance loss.
329 loss = loss + prior_loss_weight * prior_loss 333 loss = loss + prior_loss_weight * prior_loss
330 else: 334 else:
331 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 335 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
332 336
337 loss = (snr_weights * loss).mean()
333 acc = (model_pred == target).float().mean() 338 acc = (model_pred == target).float().mean()
334 339
335 return loss, acc, bsz 340 return loss, acc, bsz