From a2c240c8c55dfe930657f66372975d6f26feb168 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 14 Jan 2023 10:02:30 +0100 Subject: TI: Prepare UNet with Accelerate as well --- training/common.py | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) (limited to 'training/common.py') diff --git a/training/common.py b/training/common.py index 8083137..5d1e3f9 100644 --- a/training/common.py +++ b/training/common.py @@ -316,30 +316,29 @@ def train_loop( cur_loss_val = AverageMeter() cur_acc_val = AverageMeter() - with torch.inference_mode(): - with on_eval(): - for step, batch in enumerate(val_dataloader): - loss, acc, bsz = loss_step(step, batch, True) + with torch.inference_mode(), on_eval(): + for step, batch in enumerate(val_dataloader): + loss, acc, bsz = loss_step(step, batch, True) - loss = loss.detach_() - acc = acc.detach_() + loss = loss.detach_() + acc = acc.detach_() - cur_loss_val.update(loss, bsz) - cur_acc_val.update(acc, bsz) + cur_loss_val.update(loss, bsz) + cur_acc_val.update(acc, bsz) - avg_loss_val.update(loss, bsz) - avg_acc_val.update(acc, bsz) + avg_loss_val.update(loss, bsz) + avg_acc_val.update(acc, bsz) - local_progress_bar.update(1) - global_progress_bar.update(1) + local_progress_bar.update(1) + global_progress_bar.update(1) - logs = { - "val/loss": avg_loss_val.avg.item(), - "val/acc": avg_acc_val.avg.item(), - "val/cur_loss": loss.item(), - "val/cur_acc": acc.item(), - } - local_progress_bar.set_postfix(**logs) + logs = { + "val/loss": avg_loss_val.avg.item(), + "val/acc": avg_acc_val.avg.item(), + "val/cur_loss": loss.item(), + "val/cur_acc": acc.item(), + } + local_progress_bar.set_postfix(**logs) logs["val/cur_loss"] = cur_loss_val.avg.item() logs["val/cur_acc"] = cur_acc_val.avg.item() -- cgit v1.2.3-54-g00ecf