Diffstat (limited to 'training'):

 -rw-r--r--  training/functional.py | 32 ++++++++++++++++----------------
 -rw-r--r--  training/util.py       |  7 ++++++-
 2 files changed, 22 insertions, 17 deletions
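No commit message is shown on this page, but the diff itself reads clearly: the validation loop is made consistent with gradient_accumulation_steps (the per-epoch step count, the loss scaling, and the progress-bar updates now all work at accumulated-step granularity, with the per-epoch meters folded into the running ones once per epoch), and AverageMeter gains min/max tracking of its smoothed average, which best_acc and best_acc_val now read.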
diff --git a/training/functional.py b/training/functional.py
index 3036ed9..6ae35a0 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -468,7 +468,8 @@ def train_loop(
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
-    num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0
+    num_val_steps_per_epoch = math.ceil(
+        len(val_dataloader) / gradient_accumulation_steps) if val_dataloader is not None else 0
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
@@ -476,8 +477,8 @@ def train_loop(
     global_step = 0
     cache = {}
 
-    best_acc = avg_acc.avg
-    best_acc_val = avg_acc_val.avg
+    best_acc = avg_acc.max
+    best_acc_val = avg_acc_val.max
 
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
@@ -591,35 +592,34 @@ def train_loop(
         on_after_epoch()
 
         if val_dataloader is not None:
-            cur_loss_val = AverageMeter()
-            cur_acc_val = AverageMeter()
+            cur_loss_val = AverageMeter(power=1)
+            cur_acc_val = AverageMeter(power=1)
 
             with torch.inference_mode(), on_eval():
                 for step, batch in enumerate(val_dataloader):
                     loss, acc, bsz = loss_step(step, batch, cache, True)
-
-                    loss = loss.detach_()
-                    acc = acc.detach_()
+                    loss /= gradient_accumulation_steps
 
                     cur_loss_val.update(loss.item(), bsz)
                     cur_acc_val.update(acc.item(), bsz)
 
-                    avg_loss_val.update(loss.item(), bsz)
-                    avg_acc_val.update(acc.item(), bsz)
-
-                    local_progress_bar.update(1)
-                    global_progress_bar.update(1)
-
                     logs = {
-                        "val/loss": avg_loss_val.avg,
-                        "val/acc": avg_acc_val.avg,
                         "val/cur_loss": loss.item(),
                         "val/cur_acc": acc.item(),
                     }
                     local_progress_bar.set_postfix(**logs)
 
+                    if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)):
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
+
+            avg_loss_val.update(cur_loss_val.avg)
+            avg_acc_val.update(cur_acc_val.avg)
+
             logs["val/cur_loss"] = cur_loss_val.avg
             logs["val/cur_acc"] = cur_acc_val.avg
+            logs["val/loss"] = avg_loss_val.avg
+            logs["val/acc"] = avg_acc_val.avg
 
             accelerator.log(logs, step=global_step)
 
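A detail worth checking in the hunk above: the progress-bar total is now math.ceil(len(val_dataloader) / gradient_accumulation_steps), while the bar advances inside the loop under the ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)) condition. Those two quantities agree for every loader length; here is a standalone sanity check of that invariant (a sketch with hypothetical names, not code from the repo):

import math

def bar_updates(num_batches: int, accum: int) -> tuple[int, int]:
    # total mirrors the new num_val_steps_per_epoch computation
    total = math.ceil(num_batches / accum)
    # fired counts how often the new in-loop condition triggers
    fired = sum(
        1
        for step in range(num_batches)
        if (step + 1) % accum == 0 or (step + 1) == num_batches
    )
    return total, fired

for n in range(1, 33):
    for accum in (1, 2, 3, 4, 8):
        total, fired = bar_updates(n, accum)
        assert total == fired, (n, accum, total, fired)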
diff --git a/training/util.py b/training/util.py
index 61f1533..0b6bea9 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 import json
+import math
 from typing import Iterable, Any
 from contextlib import contextmanager
 
@@ -23,7 +24,9 @@ class AverageMeter:
 
     def reset(self):
         self.step = 0
-        self.avg = 0
+        self.min = math.inf
+        self.max = 0.0
+        self.avg = 0.0
 
     def get_decay(self):
         if self.step <= 0:
@@ -35,6 +38,8 @@ class AverageMeter:
         for _ in range(n):
             self.step += n
             self.avg += (val - self.avg) * self.get_decay()
+            self.min = min(self.min, self.avg)
+            self.max = max(self.max, self.avg)
 
 
 class EMAModel(EMAModel_):
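For reference, a minimal self-contained sketch of the updated meter. The get_decay() body is not visible in this diff, so the decay used below (1 / step ** power, which at power=1 reduces to a plain running mean) is an assumption, and the sketch steps by 1 per sample where the repo's update() adds n inside its loop:

import math

class AverageMeterSketch:
    def __init__(self, power: float = 1):
        self.power = power
        self.reset()

    def reset(self):
        self.step = 0
        self.min = math.inf  # running minimum of the smoothed average
        self.max = 0.0       # running maximum; best_acc now reads this
        self.avg = 0.0

    def get_decay(self) -> float:
        if self.step <= 0:
            return 1.0
        return 1.0 / self.step ** self.power  # assumed decay, see note above

    def update(self, val: float, n: int = 1):
        for _ in range(n):
            self.step += 1  # simplified; the repo adds n per iteration
            self.avg += (val - self.avg) * self.get_decay()
            self.min = min(self.min, self.avg)
            self.max = max(self.max, self.avg)

meter = AverageMeterSketch(power=1)
for acc in (0.2, 0.9, 0.4):
    meter.update(acc)
print(round(meter.avg, 3), round(meter.max, 3))  # 0.5 0.55

The practical effect: with best_acc = avg_acc.max, a later dip in accuracy no longer lowers the checkpointing baseline, whereas the old best_acc = avg_acc.avg tracked whatever the meter currently read.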
