summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py32
-rw-r--r--training/util.py7
2 files changed, 22 insertions, 17 deletions
diff --git a/training/functional.py b/training/functional.py
index 3036ed9..6ae35a0 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -468,7 +468,8 @@ def train_loop(
468 callbacks: TrainingCallbacks = TrainingCallbacks(), 468 callbacks: TrainingCallbacks = TrainingCallbacks(),
469): 469):
470 num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) 470 num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
471 num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0 471 num_val_steps_per_epoch = math.ceil(
472 len(val_dataloader) / gradient_accumulation_steps) if val_dataloader is not None else 0
472 473
473 num_training_steps = num_training_steps_per_epoch * num_epochs 474 num_training_steps = num_training_steps_per_epoch * num_epochs
474 num_val_steps = num_val_steps_per_epoch * num_epochs 475 num_val_steps = num_val_steps_per_epoch * num_epochs
@@ -476,8 +477,8 @@ def train_loop(
476 global_step = 0 477 global_step = 0
477 cache = {} 478 cache = {}
478 479
479 best_acc = avg_acc.avg 480 best_acc = avg_acc.max
480 best_acc_val = avg_acc_val.avg 481 best_acc_val = avg_acc_val.max
481 482
482 local_progress_bar = tqdm( 483 local_progress_bar = tqdm(
483 range(num_training_steps_per_epoch + num_val_steps_per_epoch), 484 range(num_training_steps_per_epoch + num_val_steps_per_epoch),
@@ -591,35 +592,34 @@ def train_loop(
591 on_after_epoch() 592 on_after_epoch()
592 593
593 if val_dataloader is not None: 594 if val_dataloader is not None:
594 cur_loss_val = AverageMeter() 595 cur_loss_val = AverageMeter(power=1)
595 cur_acc_val = AverageMeter() 596 cur_acc_val = AverageMeter(power=1)
596 597
597 with torch.inference_mode(), on_eval(): 598 with torch.inference_mode(), on_eval():
598 for step, batch in enumerate(val_dataloader): 599 for step, batch in enumerate(val_dataloader):
599 loss, acc, bsz = loss_step(step, batch, cache, True) 600 loss, acc, bsz = loss_step(step, batch, cache, True)
600 601 loss /= gradient_accumulation_steps
601 loss = loss.detach_()
602 acc = acc.detach_()
603 602
604 cur_loss_val.update(loss.item(), bsz) 603 cur_loss_val.update(loss.item(), bsz)
605 cur_acc_val.update(acc.item(), bsz) 604 cur_acc_val.update(acc.item(), bsz)
606 605
607 avg_loss_val.update(loss.item(), bsz)
608 avg_acc_val.update(acc.item(), bsz)
609
610 local_progress_bar.update(1)
611 global_progress_bar.update(1)
612
613 logs = { 606 logs = {
614 "val/loss": avg_loss_val.avg,
615 "val/acc": avg_acc_val.avg,
616 "val/cur_loss": loss.item(), 607 "val/cur_loss": loss.item(),
617 "val/cur_acc": acc.item(), 608 "val/cur_acc": acc.item(),
618 } 609 }
619 local_progress_bar.set_postfix(**logs) 610 local_progress_bar.set_postfix(**logs)
620 611
612 if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)):
613 local_progress_bar.update(1)
614 global_progress_bar.update(1)
615
616 avg_loss_val.update(cur_loss_val.avg)
617 avg_acc_val.update(cur_acc_val.avg)
618
621 logs["val/cur_loss"] = cur_loss_val.avg 619 logs["val/cur_loss"] = cur_loss_val.avg
622 logs["val/cur_acc"] = cur_acc_val.avg 620 logs["val/cur_acc"] = cur_acc_val.avg
621 logs["val/loss"] = avg_loss_val.avg
622 logs["val/acc"] = avg_acc_val.avg
623 623
624 accelerator.log(logs, step=global_step) 624 accelerator.log(logs, step=global_step)
625 625
diff --git a/training/util.py b/training/util.py
index 61f1533..0b6bea9 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,5 +1,6 @@
1from pathlib import Path 1from pathlib import Path
2import json 2import json
3import math
3from typing import Iterable, Any 4from typing import Iterable, Any
4from contextlib import contextmanager 5from contextlib import contextmanager
5 6
@@ -23,7 +24,9 @@ class AverageMeter:
23 24
24 def reset(self): 25 def reset(self):
25 self.step = 0 26 self.step = 0
26 self.avg = 0 27 self.min = math.inf
28 self.max = 0.0
29 self.avg = 0.0
27 30
28 def get_decay(self): 31 def get_decay(self):
29 if self.step <= 0: 32 if self.step <= 0:
@@ -35,6 +38,8 @@ class AverageMeter:
35 for _ in range(n): 38 for _ in range(n):
36 self.step += n 39 self.step += n
37 self.avg += (val - self.avg) * self.get_decay() 40 self.avg += (val - self.avg) * self.get_decay()
41 self.min = min(self.min, self.avg)
42 self.max = max(self.max, self.avg)
38 43
39 44
40class EMAModel(EMAModel_): 45class EMAModel(EMAModel_):