From ec762fd3afaa6df0715fa1cbe9e6f088b9276ea0 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 28 Apr 2023 16:22:06 +0200
Subject: Fixed loss/acc logging

---
 training/functional.py | 32 ++++++++++++++++----------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/training/functional.py b/training/functional.py
index 3036ed9..6ae35a0 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -468,7 +468,8 @@ def train_loop(
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
-    num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0
+    num_val_steps_per_epoch = math.ceil(
+        len(val_dataloader) / gradient_accumulation_steps) if val_dataloader is not None else 0
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
@@ -476,8 +477,8 @@ def train_loop(
     global_step = 0
     cache = {}
 
-    best_acc = avg_acc.avg
-    best_acc_val = avg_acc_val.avg
+    best_acc = avg_acc.max
+    best_acc_val = avg_acc_val.max
 
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
@@ -591,35 +592,34 @@ def train_loop(
 
         on_after_epoch()
 
         if val_dataloader is not None:
-            cur_loss_val = AverageMeter()
-            cur_acc_val = AverageMeter()
+            cur_loss_val = AverageMeter(power=1)
+            cur_acc_val = AverageMeter(power=1)
 
             with torch.inference_mode(), on_eval():
                 for step, batch in enumerate(val_dataloader):
                     loss, acc, bsz = loss_step(step, batch, cache, True)
-
-                    loss = loss.detach_()
-                    acc = acc.detach_()
+                    loss /= gradient_accumulation_steps
 
                     cur_loss_val.update(loss.item(), bsz)
                     cur_acc_val.update(acc.item(), bsz)
-                    avg_loss_val.update(loss.item(), bsz)
-                    avg_acc_val.update(acc.item(), bsz)
-
-                    local_progress_bar.update(1)
-                    global_progress_bar.update(1)
-
                     logs = {
-                        "val/loss": avg_loss_val.avg,
-                        "val/acc": avg_acc_val.avg,
                         "val/cur_loss": loss.item(),
                         "val/cur_acc": acc.item(),
                     }
                     local_progress_bar.set_postfix(**logs)
 
+                    if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)):
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
+
+            avg_loss_val.update(cur_loss_val.avg)
+            avg_acc_val.update(cur_acc_val.avg)
+
             logs["val/cur_loss"] = cur_loss_val.avg
             logs["val/cur_acc"] = cur_acc_val.avg
+            logs["val/loss"] = avg_loss_val.avg
+            logs["val/acc"] = avg_acc_val.avg
 
             accelerator.log(logs, step=global_step)
 
--
cgit v1.2.3-54-g00ecf
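
Note on the fix: this commit brings validation in line with training's gradient-accumulation scale. Each validation loss is divided by gradient_accumulation_steps (presumably matching the scaling applied in the training branch), the progress bars are sized and ticked in accumulation-window units instead of raw batches, and the running avg_loss_val / avg_acc_val meters are now updated once per epoch with the epoch averages rather than once per batch. The dropped detach_() calls are redundant under torch.inference_mode(), and best_acc / best_acc_val now read the meters' .max instead of the current .avg.

The AverageMeter class is defined elsewhere in this repo and is not shown in the patch. The following is only a sketch of the interface the patch relies on (update(value, n), .avg, .max, and a power argument); the weighting scheme and the .max semantics are assumptions, not the project's actual implementation:

class AverageMeter:
    """Running weighted mean; minimal sketch of the assumed interface."""

    def __init__(self, power: float = 1):
        # Assumption: update weights are n ** power, so power=1 is a
        # plain batch-size-weighted mean.
        self.power = power
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0.0
        self.avg = 0.0
        # Assumption: .max tracks the best running average seen so far,
        # which is what `best_acc = avg_acc.max` reads in the patch.
        self.max = float("-inf")

    def update(self, value: float, n: int = 1):
        weight = n ** self.power
        self.sum += value * weight
        self.count += weight
        self.avg = self.sum / self.count
        self.max = max(self.max, self.avg)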
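
The bar sizing and the in-loop tick condition agree: with num_val_steps_per_epoch = math.ceil(len(val_dataloader) / gradient_accumulation_steps), the condition ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)) fires exactly that many times per epoch, one tick per full accumulation window plus one for a trailing partial window. A standalone check with hypothetical batch counts:

import math

gradient_accumulation_steps = 4
for n_batches in (1, 7, 8, 9, 32):
    ticks = sum(
        1
        for step in range(n_batches)
        if ((step + 1) % gradient_accumulation_steps == 0)
        or ((step + 1) == n_batches)
    )
    # One tick per accumulation window, matching num_val_steps_per_epoch.
    assert ticks == math.ceil(n_batches / gradient_accumulation_steps)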