From a5cdb510002324b6e6cf8297ee4cfd6f25330ed2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 21 Feb 2023 12:03:00 +0100 Subject: Fix --- training/functional.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 3f5fa7e..e7c4320 100644 --- a/training/functional.py +++ b/training/functional.py @@ -375,7 +375,6 @@ def train_loop( num_val_steps = num_val_steps_per_epoch * num_epochs global_step = 0 - train_step = 0 avg_loss = AverageMeter() avg_acc = AverageMeter() @@ -439,11 +438,11 @@ def train_loop( loss, acc, bsz = loss_step(step, batch) loss /= gradient_accumulation_steps + accelerator.backward(loss) + avg_loss.update(loss.detach_(), bsz) avg_acc.update(acc.detach_(), bsz) - accelerator.backward(loss) - logs = { "train/loss": avg_loss.avg.item(), "train/acc": avg_acc.avg.item(), @@ -455,9 +454,7 @@ def train_loop( local_progress_bar.set_postfix(**logs) - train_step += 1 - - if train_step % gradient_accumulation_steps == 0: + if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): on_before_optimize(lr_scheduler.get_last_lr()[0], epoch) optimizer.step() -- cgit v1.2.3-54-g00ecf