From 4d3d318a4168ef79847737cef2c0ad8a4dafd3e7 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 29 Dec 2022 09:00:19 +0100
Subject: Training improvements

---
 train_ti.py | 21 +++++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index d7696e5..b1f6a49 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -903,12 +903,21 @@ def main():
 
             text_encoder.eval()
 
+            cur_loss_val = AverageMeter()
+            cur_acc_val = AverageMeter()
+
             with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
                     loss, acc, bsz = loop(batch)
 
-                    avg_loss_val.update(loss.detach_(), bsz)
-                    avg_acc_val.update(acc.detach_(), bsz)
+                    loss = loss.detach_()
+                    acc = acc.detach_()
+
+                    cur_loss_val.update(loss, bsz)
+                    cur_acc_val.update(acc, bsz)
+
+                    avg_loss_val.update(loss, bsz)
+                    avg_acc_val.update(acc, bsz)
 
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
@@ -921,10 +930,10 @@ def main():
                     }
                     local_progress_bar.set_postfix(**logs)
 
-            accelerator.log({
-                "val/loss": avg_loss_val.avg.item(),
-                "val/acc": avg_acc_val.avg.item(),
-            }, step=global_step)
+            logs["val/cur_loss"] = cur_loss_val.avg.item()
+            logs["val/cur_acc"] = cur_acc_val.avg.item()
+
+            accelerator.log(logs, step=global_step)
 
             local_progress_bar.clear()
             global_progress_bar.clear()
-- 
cgit v1.2.3-54-g00ecf