From ec762fd3afaa6df0715fa1cbe9e6f088b9276ea0 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 28 Apr 2023 16:22:06 +0200
Subject: Fixed loss/acc logging

---
 training/functional.py | 32 ++++++++++++++++----------------
 training/util.py       |  7 ++++++-
 2 files changed, 22 insertions(+), 17 deletions(-)

diff --git a/training/functional.py b/training/functional.py
index 3036ed9..6ae35a0 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -468,7 +468,8 @@ def train_loop(
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
-    num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0
+    num_val_steps_per_epoch = math.ceil(
+        len(val_dataloader) / gradient_accumulation_steps) if val_dataloader is not None else 0
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
@@ -476,8 +477,8 @@ def train_loop(
     global_step = 0
     cache = {}
 
-    best_acc = avg_acc.avg
-    best_acc_val = avg_acc_val.avg
+    best_acc = avg_acc.max
+    best_acc_val = avg_acc_val.max
 
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
@@ -591,35 +592,34 @@ def train_loop(
             on_after_epoch()
 
             if val_dataloader is not None:
-                cur_loss_val = AverageMeter()
-                cur_acc_val = AverageMeter()
+                cur_loss_val = AverageMeter(power=1)
+                cur_acc_val = AverageMeter(power=1)
 
                 with torch.inference_mode(), on_eval():
                     for step, batch in enumerate(val_dataloader):
                         loss, acc, bsz = loss_step(step, batch, cache, True)
-
-                        loss = loss.detach_()
-                        acc = acc.detach_()
+                        loss /= gradient_accumulation_steps
 
                         cur_loss_val.update(loss.item(), bsz)
                         cur_acc_val.update(acc.item(), bsz)
 
-                        avg_loss_val.update(loss.item(), bsz)
-                        avg_acc_val.update(acc.item(), bsz)
-
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
-
                         logs = {
-                            "val/loss": avg_loss_val.avg,
-                            "val/acc": avg_acc_val.avg,
                             "val/cur_loss": loss.item(),
                             "val/cur_acc": acc.item(),
                         }
                         local_progress_bar.set_postfix(**logs)
 
+                        if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)):
+                            local_progress_bar.update(1)
+                            global_progress_bar.update(1)
+
+                avg_loss_val.update(cur_loss_val.avg)
+                avg_acc_val.update(cur_acc_val.avg)
+
                 logs["val/cur_loss"] = cur_loss_val.avg
                 logs["val/cur_acc"] = cur_acc_val.avg
+                logs["val/loss"] = avg_loss_val.avg
+                logs["val/acc"] = avg_acc_val.avg
 
                 accelerator.log(logs, step=global_step)
 
diff --git a/training/util.py b/training/util.py
index 61f1533..0b6bea9 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 import json
+import math
 
 from typing import Iterable, Any
 from contextlib import contextmanager
@@ -23,7 +24,9 @@ class AverageMeter:
 
     def reset(self):
         self.step = 0
-        self.avg = 0
+        self.min = math.inf
+        self.max = 0.0
+        self.avg = 0.0
 
     def get_decay(self):
         if self.step <= 0:
@@ -35,6 +38,8 @@ class AverageMeter:
         for _ in range(n):
             self.step += n
             self.avg += (val - self.avg) * self.get_decay()
+            self.min = min(self.min, self.avg)
+            self.max = max(self.max, self.avg)
 
 
 class EMAModel(EMAModel_):
-- 
cgit v1.2.3-70-g09d2
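
Note: the first hunk makes the validation step count mirror the training side. With gradient accumulation, the progress bars advance once per group of gradient_accumulation_steps batches, plus once for a trailing partial group, hence the ceiling division. A minimal sketch of that arithmetic; steps_per_epoch is a hypothetical helper, not part of the patch:

import math

# One effective step per full group of `accum` batches, plus one for any
# trailing partial group -- the same ceiling division the patch uses.
def steps_per_epoch(num_batches: int, accum: int) -> int:
    return math.ceil(num_batches / accum)

assert steps_per_epoch(10, 4) == 3  # groups: batches 0-3, 4-7, 8-9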
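
For context, a self-contained sketch of AverageMeter after this commit. Only reset() and the tail of update() appear in the hunks above, so __init__, the default power, and the body of get_decay are assumptions: decay is taken as step ** -power, under which power=1 reduces the meter to an exact arithmetic mean (which is why the per-epoch validation meters are now created with AverageMeter(power=1)), while the new min/max fields track the extremes of the smoothed value, matching the new best_acc = avg_acc.max.

import math

class AverageMeter:
    # Sketch of training/util.py after this commit. __init__, the default
    # power, and get_decay's return value are assumed; the rest mirrors
    # the patched source.
    def __init__(self, power=0.5):  # default power assumed, not shown in the patch
        self.power = power
        self.reset()

    def reset(self):
        self.step = 0
        self.min = math.inf  # lowest smoothed value seen since reset
        self.max = 0.0       # highest smoothed value seen since reset
        self.avg = 0.0

    def get_decay(self):
        if self.step <= 0:
            return 1
        # Assumption: decay shrinks as step grows; power=1 gives 1/step,
        # i.e. a plain running mean rather than an EMA.
        return self.step ** -self.power

    def update(self, val, n=1):
        for _ in range(n):
            self.step += n  # as committed: step advances by n on each of the n iterations
            self.avg += (val - self.avg) * self.get_decay()
            self.min = min(self.min, self.avg)
            self.max = max(self.max, self.avg)

With power=1 the meter reproduces the arithmetic mean, and min/max follow the running average rather than the raw samples:

m = AverageMeter(power=1)
for v in (0.9, 0.7, 0.8):
    m.update(v)
print(m.avg, m.min, m.max)  # avg ~= 0.8 (plain mean), min ~= 0.8, max = 0.9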
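
Finally, a hypothetical standalone rendering of the new validation bookkeeping in train_loop, using the AverageMeter sketch above: an exact-mean meter per epoch whose average feeds the long-run meter exactly once, so val/loss and val/acc are no longer re-logged per batch, and the per-batch loss divided by gradient_accumulation_steps so validation is logged on the same scale as the accumulated training loss. ACCUM and val_losses are stand-ins for the real dataloader and settings:

ACCUM = 4                                    # stand-in for gradient_accumulation_steps
val_losses = [0.52, 0.48, 0.50, 0.47, 0.49]  # dummy per-batch losses

avg_loss_val = AverageMeter()         # long-run meter (assumed smoothed default)
cur_loss_val = AverageMeter(power=1)  # exact mean over this epoch only

for step, loss in enumerate(val_losses):
    loss /= ACCUM  # same scale as the accumulated training loss
    cur_loss_val.update(loss)
    # progress bars advance once per effective step, as in the patch:
    if (step + 1) % ACCUM == 0 or (step + 1) == len(val_losses):
        pass  # local_progress_bar.update(1); global_progress_bar.update(1)

# one long-run update per epoch instead of one per batch
avg_loss_val.update(cur_loss_val.avg)
print(cur_loss_val.avg, avg_loss_val.avg)  # epoch mean; long-run meter seeded with it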