From dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 30 Dec 2022 13:48:26 +0100 Subject: Training script improvements --- train_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'train_lora.py') diff --git a/train_lora.py b/train_lora.py index 9a42cae..de878a4 100644 --- a/train_lora.py +++ b/train_lora.py @@ -810,7 +810,7 @@ def main(): target, target_prior = torch.chunk(target, 2, dim=0) # Compute instance loss - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") # Compute prior loss prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") -- cgit v1.2.3-54-g00ecf