Diffstat (limited to 'training/functional.py')
-rw-r--r--  training/functional.py | 32
1 file changed, 16 insertions(+), 16 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 3036ed9..6ae35a0 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -468,7 +468,8 @@ def train_loop(
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
-    num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0
+    num_val_steps_per_epoch = math.ceil(
+        len(val_dataloader) / gradient_accumulation_steps) if val_dataloader is not None else 0
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
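
Note: the rewritten computation mirrors the training-step arithmetic directly above it, so the progress bars sized from these counts stay consistent. A minimal standalone sketch of the rounding behavior (the concrete numbers are illustrative, not taken from this repository):

    import math

    gradient_accumulation_steps = 4
    num_batches = 10  # stands in for len(val_dataloader)

    # One progress step per accumulation window; math.ceil keeps the
    # trailing partial window (batches 9-10 here) from being dropped.
    num_val_steps_per_epoch = math.ceil(num_batches / gradient_accumulation_steps)
    assert num_val_steps_per_epoch == 3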
@@ -476,8 +477,8 @@ def train_loop(
     global_step = 0
     cache = {}
 
-    best_acc = avg_acc.avg
-    best_acc_val = avg_acc_val.avg
+    best_acc = avg_acc.max
+    best_acc_val = avg_acc_val.max
 
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
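
The switch from .avg to .max here, together with the AverageMeter(power=1) calls in the next hunk, implies an extended meter class. Its definition is not part of this diff, so the following is only a plausible sketch of the interface the changed lines depend on (the power semantics in particular are an assumption):

    class AverageMeter:
        # Hypothetical reconstruction. power is accepted but unused here;
        # in the real class it presumably selects a weighting/smoothing mode.
        def __init__(self, power: float = 0):
            self.power = power
            self.sum = 0.0
            self.count = 0
            self.avg = 0.0
            self.max = float("-inf")  # running best, read back as best_acc

        def update(self, val: float, n: int = 1):
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count
            self.max = max(self.max, self.avg)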
@@ -591,35 +592,34 @@ def train_loop(
         on_after_epoch()
 
         if val_dataloader is not None:
-            cur_loss_val = AverageMeter()
-            cur_acc_val = AverageMeter()
+            cur_loss_val = AverageMeter(power=1)
+            cur_acc_val = AverageMeter(power=1)
 
             with torch.inference_mode(), on_eval():
                 for step, batch in enumerate(val_dataloader):
                     loss, acc, bsz = loss_step(step, batch, cache, True)
-
-                    loss = loss.detach_()
-                    acc = acc.detach_()
+                    loss /= gradient_accumulation_steps
 
                     cur_loss_val.update(loss.item(), bsz)
                     cur_acc_val.update(acc.item(), bsz)
 
-                    avg_loss_val.update(loss.item(), bsz)
-                    avg_acc_val.update(acc.item(), bsz)
-
-                    local_progress_bar.update(1)
-                    global_progress_bar.update(1)
-
                     logs = {
-                        "val/loss": avg_loss_val.avg,
-                        "val/acc": avg_acc_val.avg,
                         "val/cur_loss": loss.item(),
                         "val/cur_acc": acc.item(),
                     }
                     local_progress_bar.set_postfix(**logs)
 
+                    if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)):
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
+
+            avg_loss_val.update(cur_loss_val.avg)
+            avg_acc_val.update(cur_acc_val.avg)
+
             logs["val/cur_loss"] = cur_loss_val.avg
             logs["val/cur_acc"] = cur_acc_val.avg
+            logs["val/loss"] = avg_loss_val.avg
+            logs["val/acc"] = avg_acc_val.avg
 
             accelerator.log(logs, step=global_step)
 
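
Three things change in the validation loop: each batch loss is divided by gradient_accumulation_steps, presumably to mirror the scaling the training path applies under gradient accumulation; the explicit detach_() calls are dropped, which is safe because torch.inference_mode() already yields tensors with no autograd history; and the progress bars now advance once per accumulation window instead of once per batch, with the running meters updated once per epoch from the window averages. The window-boundary test in isolation (sizes are illustrative):

    batches = 10                      # stands in for len(val_dataloader)
    gradient_accumulation_steps = 4

    window_ends = [
        step for step in range(batches)
        if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == batches)
    ]
    # Windows close after batches 4, 8 and 10 (0-indexed steps 3, 7, 9),
    # matching num_val_steps_per_epoch = math.ceil(10 / 4) = 3.
    assert window_ends == [3, 7, 9]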