diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py index 4565612..2d6553a 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -309,8 +309,12 @@ def loss_step( | |||
309 | # Get the target for loss depending on the prediction type | 309 | # Get the target for loss depending on the prediction type |
310 | if noise_scheduler.config.prediction_type == "epsilon": | 310 | if noise_scheduler.config.prediction_type == "epsilon": |
311 | target = noise | 311 | target = noise |
312 | snr_weights = 1 | ||
312 | elif noise_scheduler.config.prediction_type == "v_prediction": | 313 | elif noise_scheduler.config.prediction_type == "v_prediction": |
313 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | 314 | target = noise_scheduler.get_velocity(latents, noise, timesteps) |
315 | snr = target / (1 - target) | ||
316 | snr /= snr + 1 | ||
317 | snr_weights = torch.minimum(snr, torch.tensor([5], device=latents.device)) | ||
314 | else: | 318 | else: |
315 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 319 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |
316 | 320 | ||
@@ -320,16 +324,17 @@ def loss_step( | |||
320 | target, target_prior = torch.chunk(target, 2, dim=0) | 324 | target, target_prior = torch.chunk(target, 2, dim=0) |
321 | 325 | ||
322 | # Compute instance loss | 326 | # Compute instance loss |
323 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | 327 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") |
324 | 328 | ||
325 | # Compute prior loss | 329 | # Compute prior loss |
326 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | 330 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none") |
327 | 331 | ||
328 | # Add the prior loss to the instance loss. | 332 | # Add the prior loss to the instance loss. |
329 | loss = loss + prior_loss_weight * prior_loss | 333 | loss = loss + prior_loss_weight * prior_loss |
330 | else: | 334 | else: |
331 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | 335 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") |
332 | 336 | ||
337 | loss = (snr_weights * loss).mean() | ||
333 | acc = (model_pred == target).float().mean() | 338 | acc = (model_pred == target).float().mean() |
334 | 339 | ||
335 | return loss, acc, bsz | 340 | return loss, acc, bsz |