summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py9
1 files changed, 3 insertions, 6 deletions
diff --git a/training/functional.py b/training/functional.py
index 3f5fa7e..e7c4320 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -375,7 +375,6 @@ def train_loop(
375 num_val_steps = num_val_steps_per_epoch * num_epochs 375 num_val_steps = num_val_steps_per_epoch * num_epochs
376 376
377 global_step = 0 377 global_step = 0
378 train_step = 0
379 378
380 avg_loss = AverageMeter() 379 avg_loss = AverageMeter()
381 avg_acc = AverageMeter() 380 avg_acc = AverageMeter()
@@ -439,11 +438,11 @@ def train_loop(
439 loss, acc, bsz = loss_step(step, batch) 438 loss, acc, bsz = loss_step(step, batch)
440 loss /= gradient_accumulation_steps 439 loss /= gradient_accumulation_steps
441 440
441 accelerator.backward(loss)
442
442 avg_loss.update(loss.detach_(), bsz) 443 avg_loss.update(loss.detach_(), bsz)
443 avg_acc.update(acc.detach_(), bsz) 444 avg_acc.update(acc.detach_(), bsz)
444 445
445 accelerator.backward(loss)
446
447 logs = { 446 logs = {
448 "train/loss": avg_loss.avg.item(), 447 "train/loss": avg_loss.avg.item(),
449 "train/acc": avg_acc.avg.item(), 448 "train/acc": avg_acc.avg.item(),
@@ -455,9 +454,7 @@ def train_loop(
455 454
456 local_progress_bar.set_postfix(**logs) 455 local_progress_bar.set_postfix(**logs)
457 456
458 train_step += 1 457 if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
459
460 if train_step % gradient_accumulation_steps == 0:
461 on_before_optimize(lr_scheduler.get_last_lr()[0], epoch) 458 on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
462 459
463 optimizer.step() 460 optimizer.step()