diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/train_lora.py b/train_lora.py index 9a42cae..de878a4 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -810,7 +810,7 @@ def main(): | |||
810 | target, target_prior = torch.chunk(target, 2, dim=0) | 810 | target, target_prior = torch.chunk(target, 2, dim=0) |
811 | 811 | ||
812 | # Compute instance loss | 812 | # Compute instance loss |
813 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() | 813 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
814 | 814 | ||
815 | # Compute prior loss | 815 | # Compute prior loss |
816 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | 816 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") |