 training/functional.py | 32 ++++++++++++++++----------------
 training/util.py       |  7 ++++++-
 2 files changed, 22 insertions(+), 17 deletions(-)
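
In short: the validation pass now respects gradient accumulation. Validation
steps are counted in accumulated (effective) steps using the same ceil-division
as the training side, the per-batch validation loss is divided by
gradient_accumulation_steps, and the progress bars advance once per accumulated
group rather than once per batch. AverageMeter additionally tracks the minimum
and maximum of its running average, which train_loop now uses for best_acc and
best_acc_val.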

diff --git a/training/functional.py b/training/functional.py
index 3036ed9..6ae35a0 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -468,7 +468,8 @@ def train_loop(
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
-    num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0
+    num_val_steps_per_epoch = math.ceil(
+        len(val_dataloader) / gradient_accumulation_steps) if val_dataloader is not None else 0
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
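
With gradient_accumulation_steps > 1, counting one validation step per batch
overstated how many times the progress bars would advance; the new
ceil-division mirrors the training-side count two lines above. A quick check
of the arithmetic, using made-up numbers (10 batches and a factor of 4 are
illustrative, not from the repo):

    import math

    num_batches = 10                 # hypothetical validation batch count
    gradient_accumulation_steps = 4  # hypothetical accumulation factor

    # One step per accumulated group, rounding up for the final partial group.
    num_val_steps_per_epoch = math.ceil(num_batches / gradient_accumulation_steps)
    assert num_val_steps_per_epoch == 3  # was 10 under the old per-batch count
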
@@ -476,8 +477,8 @@ def train_loop(
     global_step = 0
     cache = {}
 
-    best_acc = avg_acc.avg
-    best_acc_val = avg_acc_val.avg
+    best_acc = avg_acc.max
+    best_acc_val = avg_acc_val.max
 
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
@@ -591,35 +592,34 @@ def train_loop(
         on_after_epoch()
 
         if val_dataloader is not None:
-            cur_loss_val = AverageMeter()
-            cur_acc_val = AverageMeter()
+            cur_loss_val = AverageMeter(power=1)
+            cur_acc_val = AverageMeter(power=1)
 
             with torch.inference_mode(), on_eval():
                 for step, batch in enumerate(val_dataloader):
                     loss, acc, bsz = loss_step(step, batch, cache, True)
-
-                    loss = loss.detach_()
-                    acc = acc.detach_()
+                    loss /= gradient_accumulation_steps
 
                     cur_loss_val.update(loss.item(), bsz)
                     cur_acc_val.update(acc.item(), bsz)
 
-                    avg_loss_val.update(loss.item(), bsz)
-                    avg_acc_val.update(acc.item(), bsz)
-
-                    local_progress_bar.update(1)
-                    global_progress_bar.update(1)
-
                     logs = {
-                        "val/loss": avg_loss_val.avg,
-                        "val/acc": avg_acc_val.avg,
                         "val/cur_loss": loss.item(),
                         "val/cur_acc": acc.item(),
                     }
                     local_progress_bar.set_postfix(**logs)
 
+                    if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)):
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
+
+            avg_loss_val.update(cur_loss_val.avg)
+            avg_acc_val.update(cur_acc_val.avg)
+
             logs["val/cur_loss"] = cur_loss_val.avg
             logs["val/cur_acc"] = cur_acc_val.avg
+            logs["val/loss"] = avg_loss_val.avg
+            logs["val/acc"] = avg_acc_val.avg
 
             accelerator.log(logs, step=global_step)
 
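
The reworked validation loop divides each batch loss by the accumulation
factor so it stays comparable to the training loss (which is accumulated over
that many batches), gathers per-epoch statistics in cur_loss_val/cur_acc_val,
and folds them into the cross-epoch meters once per epoch. The step-gating
pattern in isolation, as a runnable sketch with stand-in data (batches and the
constant loss are placeholders, not the repo's loss_step):

    import math

    batches = list(range(10))        # stand-in for val_dataloader
    gradient_accumulation_steps = 4
    effective_steps = 0

    for step, batch in enumerate(batches):
        loss = 1.0                           # loss_step(...) would go here
        loss /= gradient_accumulation_steps  # match the accumulated training scale

        # Advance progress once per full group, plus once for the partial
        # group that ends the epoch.
        if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(batches)):
            effective_steps += 1

    assert effective_steps == math.ceil(len(batches) / gradient_accumulation_steps)
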
diff --git a/training/util.py b/training/util.py
index 61f1533..0b6bea9 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 import json
+import math
 from typing import Iterable, Any
 from contextlib import contextmanager
 
@@ -23,7 +24,9 @@ class AverageMeter:
 
     def reset(self):
         self.step = 0
-        self.avg = 0
+        self.min = math.inf
+        self.max = 0.0
+        self.avg = 0.0
 
     def get_decay(self):
         if self.step <= 0:
@@ -35,6 +38,8 @@ class AverageMeter:
         for _ in range(n):
             self.step += n
             self.avg += (val - self.avg) * self.get_decay()
+            self.min = min(self.min, self.avg)
+            self.max = max(self.max, self.avg)
 
 
 class EMAModel(EMAModel_):
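
The hunks above show only fragments of AverageMeter. Below is a self-contained
approximation for context; the constructor and the body of get_decay are
assumptions (a power-law decay whose power=1 case reduces to a plain running
mean, which would fit the AverageMeter(power=1) calls in train_loop), not code
confirmed by this diff:

    import math

    class AverageMeter:
        def __init__(self, power=2):
            self.power = power  # assumed decay exponent; power=1 gives a plain mean
            self.reset()

        def reset(self):
            self.step = 0
            self.min = math.inf
            self.max = 0.0
            self.avg = 0.0

        def get_decay(self):
            # Assumed schedule; the diff only shows the `if self.step <= 0:` guard.
            if self.step <= 0:
                return 1.0
            return self.step ** -self.power

        def update(self, val, n=1):
            for _ in range(n):
                self.step += n  # kept as in the source, though `+= 1` looks intended
                self.avg += (val - self.avg) * self.get_decay()
                # Track extrema of the running average so callers can read the
                # best value seen, as train_loop now does via avg_acc.max.
                self.min = min(self.min, self.avg)
                self.max = max(self.max, self.avg)

    meter = AverageMeter(power=1)
    for v in (0.2, 0.4, 0.6):
        meter.update(v)
    print(meter.avg, meter.max)  # running mean (~0.4) and its peak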