diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 11 |
1 file changed, 8 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py
index 4565612..2d6553a 100644
--- a/training/functional.py
+++ b/training/functional.py
| @@ -309,8 +309,12 @@ def loss_step( | |||
| 309 | # Get the target for loss depending on the prediction type | 309 | # Get the target for loss depending on the prediction type |
| 310 | if noise_scheduler.config.prediction_type == "epsilon": | 310 | if noise_scheduler.config.prediction_type == "epsilon": |
| 311 | target = noise | 311 | target = noise |
| 312 | snr_weights = 1 | ||
| 312 | elif noise_scheduler.config.prediction_type == "v_prediction": | 313 | elif noise_scheduler.config.prediction_type == "v_prediction": |
| 313 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | 314 | target = noise_scheduler.get_velocity(latents, noise, timesteps) |
| 315 | snr = target / (1 - target) | ||
| 316 | snr /= snr + 1 | ||
| 317 | snr_weights = torch.minimum(snr, torch.tensor([5], device=latents.device)) | ||
| 314 | else: | 318 | else: |
| 315 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 319 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |
| 316 | 320 | ||
| @@ -320,16 +324,17 @@ def loss_step( | |||
| 320 | target, target_prior = torch.chunk(target, 2, dim=0) | 324 | target, target_prior = torch.chunk(target, 2, dim=0) |
| 321 | 325 | ||
| 322 | # Compute instance loss | 326 | # Compute instance loss |
| 323 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | 327 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") |
| 324 | 328 | ||
| 325 | # Compute prior loss | 329 | # Compute prior loss |
| 326 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | 330 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none") |
| 327 | 331 | ||
| 328 | # Add the prior loss to the instance loss. | 332 | # Add the prior loss to the instance loss. |
| 329 | loss = loss + prior_loss_weight * prior_loss | 333 | loss = loss + prior_loss_weight * prior_loss |
| 330 | else: | 334 | else: |
| 331 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | 335 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") |
| 332 | 336 | ||
| 337 | loss = (snr_weights * loss).mean() | ||
| 333 | acc = (model_pred == target).float().mean() | 338 | acc = (model_pred == target).float().mean() |
| 334 | 339 | ||
| 335 | return loss, acc, bsz | 340 | return loss, acc, bsz |
