diff options
Diffstat (limited to 'training/common.py')
| -rw-r--r-- | training/common.py | 37 |
1 files changed, 18 insertions, 19 deletions
diff --git a/training/common.py b/training/common.py index 8083137..5d1e3f9 100644 --- a/training/common.py +++ b/training/common.py | |||
| @@ -316,30 +316,29 @@ def train_loop( | |||
| 316 | cur_loss_val = AverageMeter() | 316 | cur_loss_val = AverageMeter() |
| 317 | cur_acc_val = AverageMeter() | 317 | cur_acc_val = AverageMeter() |
| 318 | 318 | ||
| 319 | with torch.inference_mode(): | 319 | with torch.inference_mode(), on_eval(): |
| 320 | with on_eval(): | 320 | for step, batch in enumerate(val_dataloader): |
| 321 | for step, batch in enumerate(val_dataloader): | 321 | loss, acc, bsz = loss_step(step, batch, True) |
| 322 | loss, acc, bsz = loss_step(step, batch, True) | ||
| 323 | 322 | ||
| 324 | loss = loss.detach_() | 323 | loss = loss.detach_() |
| 325 | acc = acc.detach_() | 324 | acc = acc.detach_() |
| 326 | 325 | ||
| 327 | cur_loss_val.update(loss, bsz) | 326 | cur_loss_val.update(loss, bsz) |
| 328 | cur_acc_val.update(acc, bsz) | 327 | cur_acc_val.update(acc, bsz) |
| 329 | 328 | ||
| 330 | avg_loss_val.update(loss, bsz) | 329 | avg_loss_val.update(loss, bsz) |
| 331 | avg_acc_val.update(acc, bsz) | 330 | avg_acc_val.update(acc, bsz) |
| 332 | 331 | ||
| 333 | local_progress_bar.update(1) | 332 | local_progress_bar.update(1) |
| 334 | global_progress_bar.update(1) | 333 | global_progress_bar.update(1) |
| 335 | 334 | ||
| 336 | logs = { | 335 | logs = { |
| 337 | "val/loss": avg_loss_val.avg.item(), | 336 | "val/loss": avg_loss_val.avg.item(), |
| 338 | "val/acc": avg_acc_val.avg.item(), | 337 | "val/acc": avg_acc_val.avg.item(), |
| 339 | "val/cur_loss": loss.item(), | 338 | "val/cur_loss": loss.item(), |
| 340 | "val/cur_acc": acc.item(), | 339 | "val/cur_acc": acc.item(), |
| 341 | } | 340 | } |
| 342 | local_progress_bar.set_postfix(**logs) | 341 | local_progress_bar.set_postfix(**logs) |
| 343 | 342 | ||
| 344 | logs["val/cur_loss"] = cur_loss_val.avg.item() | 343 | logs["val/cur_loss"] = cur_loss_val.avg.item() |
| 345 | logs["val/cur_acc"] = cur_acc_val.avg.item() | 344 | logs["val/cur_acc"] = cur_acc_val.avg.item() |
