summaryrefslogtreecommitdiffstats
path: root/training/common.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/common.py')
-rw-r--r--training/common.py37
1 files changed, 18 insertions, 19 deletions
diff --git a/training/common.py b/training/common.py
index 8083137..5d1e3f9 100644
--- a/training/common.py
+++ b/training/common.py
@@ -316,30 +316,29 @@ def train_loop(
316 cur_loss_val = AverageMeter() 316 cur_loss_val = AverageMeter()
317 cur_acc_val = AverageMeter() 317 cur_acc_val = AverageMeter()
318 318
319 with torch.inference_mode(): 319 with torch.inference_mode(), on_eval():
320 with on_eval(): 320 for step, batch in enumerate(val_dataloader):
321 for step, batch in enumerate(val_dataloader): 321 loss, acc, bsz = loss_step(step, batch, True)
322 loss, acc, bsz = loss_step(step, batch, True)
323 322
324 loss = loss.detach_() 323 loss = loss.detach_()
325 acc = acc.detach_() 324 acc = acc.detach_()
326 325
327 cur_loss_val.update(loss, bsz) 326 cur_loss_val.update(loss, bsz)
328 cur_acc_val.update(acc, bsz) 327 cur_acc_val.update(acc, bsz)
329 328
330 avg_loss_val.update(loss, bsz) 329 avg_loss_val.update(loss, bsz)
331 avg_acc_val.update(acc, bsz) 330 avg_acc_val.update(acc, bsz)
332 331
333 local_progress_bar.update(1) 332 local_progress_bar.update(1)
334 global_progress_bar.update(1) 333 global_progress_bar.update(1)
335 334
336 logs = { 335 logs = {
337 "val/loss": avg_loss_val.avg.item(), 336 "val/loss": avg_loss_val.avg.item(),
338 "val/acc": avg_acc_val.avg.item(), 337 "val/acc": avg_acc_val.avg.item(),
339 "val/cur_loss": loss.item(), 338 "val/cur_loss": loss.item(),
340 "val/cur_acc": acc.item(), 339 "val/cur_acc": acc.item(),
341 } 340 }
342 local_progress_bar.set_postfix(**logs) 341 local_progress_bar.set_postfix(**logs)
343 342
344 logs["val/cur_loss"] = cur_loss_val.avg.item() 343 logs["val/cur_loss"] = cur_loss_val.avg.item()
345 logs["val/cur_acc"] = cur_acc_val.avg.item() 344 logs["val/cur_acc"] = cur_acc_val.avg.item()