Diffstat (limited to 'training/functional.py')
-rw-r--r--  training/functional.py  32  ++++++++++++++++----------------
1 file changed, 16 insertions, 16 deletions
diff --git a/training/functional.py b/training/functional.py
index 3036ed9..6ae35a0 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -468,7 +468,8 @@ def train_loop(
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
-    num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0
+    num_val_steps_per_epoch = math.ceil(
+        len(val_dataloader) / gradient_accumulation_steps) if val_dataloader is not None else 0
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
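
With gradient accumulation, each progress-bar step covers gradient_accumulation_steps micro-batches, so this hunk makes the validation count mirror the training count one line above: ceil(len(val_dataloader) / gradient_accumulation_steps) steps per epoch. A standalone sanity check with made-up sizes (illustrative only, not from the repository):

import math

# Hypothetical sizes, chosen only to illustrate the rounding.
num_val_batches = 10                 # stand-in for len(val_dataloader)
gradient_accumulation_steps = 4

# New formula from this hunk: one progress step per accumulation
# window, with a trailing partial window still counting as a step.
num_val_steps = math.ceil(num_val_batches / gradient_accumulation_steps)

# The window-boundary condition added in the last hunk fires exactly
# that many times, so the progress bars line up with the new count.
fires = sum(
    1 for step in range(num_val_batches)
    if (step + 1) % gradient_accumulation_steps == 0
    or (step + 1) == num_val_batches
)
assert fires == num_val_steps == 3   # windows [0..3], [4..7], [8..9]
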
@@ -476,8 +477,8 @@ def train_loop(
     global_step = 0
     cache = {}
 
-    best_acc = avg_acc.avg
-    best_acc_val = avg_acc_val.avg
+    best_acc = avg_acc.max
+    best_acc_val = avg_acc_val.max
 
    local_progress_bar = tqdm(
        range(num_training_steps_per_epoch + num_val_steps_per_epoch),
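
best_acc now seeds from the meter's running maximum rather than its running mean, so the "best so far" comparison starts from the peak accuracy rather than the average. The AverageMeter class itself is not part of this diff; the following sketch is only a reconstruction of the interface the new code relies on (update with an optional weight, avg, max, and the power argument used in the next hunk), not the repository's actual implementation:

class AverageMeter:
    # Sketch only: a weighted running (power) mean plus a running max,
    # reconstructed from how the diff uses the class.
    def __init__(self, power: float = 1):
        self.power = power
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.max = 0.0

    def update(self, val: float, n: int = 1):
        # power=1 reduces this to a plain weighted arithmetic mean.
        self.sum += (val ** self.power) * n
        self.count += n
        self.max = max(self.max, val)

    @property
    def avg(self) -> float:
        if self.count == 0:
            return 0.0
        return (self.sum / self.count) ** (1 / self.power)
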
@@ -591,35 +592,34 @@ def train_loop(
         on_after_epoch()
 
         if val_dataloader is not None:
-            cur_loss_val = AverageMeter()
-            cur_acc_val = AverageMeter()
+            cur_loss_val = AverageMeter(power=1)
+            cur_acc_val = AverageMeter(power=1)
 
             with torch.inference_mode(), on_eval():
                 for step, batch in enumerate(val_dataloader):
                     loss, acc, bsz = loss_step(step, batch, cache, True)
-
-                    loss = loss.detach_()
-                    acc = acc.detach_()
+                    loss /= gradient_accumulation_steps
 
                     cur_loss_val.update(loss.item(), bsz)
                     cur_acc_val.update(acc.item(), bsz)
 
-                    avg_loss_val.update(loss.item(), bsz)
-                    avg_acc_val.update(acc.item(), bsz)
-
-                    local_progress_bar.update(1)
-                    global_progress_bar.update(1)
-
                     logs = {
-                        "val/loss": avg_loss_val.avg,
-                        "val/acc": avg_acc_val.avg,
                         "val/cur_loss": loss.item(),
                         "val/cur_acc": acc.item(),
                     }
                     local_progress_bar.set_postfix(**logs)
 
+                    if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)):
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
+
+                        avg_loss_val.update(cur_loss_val.avg)
+                        avg_acc_val.update(cur_acc_val.avg)
+
             logs["val/cur_loss"] = cur_loss_val.avg
             logs["val/cur_acc"] = cur_acc_val.avg
+            logs["val/loss"] = avg_loss_val.avg
+            logs["val/acc"] = avg_acc_val.avg
 
             accelerator.log(logs, step=global_step)
 
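
Taken together, the last hunk makes validation follow the same accumulation rhythm as training: each micro-batch loss is scaled by 1/gradient_accumulation_steps, per-batch values feed the cur_* meters, and the epoch-level avg_* meters and progress bars advance only when an accumulation window closes. Stripped of the loop's plumbing, and reusing the AverageMeter sketch above with made-up numbers, the bookkeeping pattern is roughly:

# Illustrative values only; in the diff these come from loss_step().
fake_losses = [0.9, 0.8, 0.7, 0.6, 0.5]
gradient_accumulation_steps = 4

cur_loss_val = AverageMeter(power=1)
avg_loss_val = AverageMeter(power=1)

for step, loss in enumerate(fake_losses):
    loss /= gradient_accumulation_steps      # as in the diff
    cur_loss_val.update(loss, n=1)           # per micro-batch
    if ((step + 1) % gradient_accumulation_steps == 0) or (
        (step + 1) == len(fake_losses)
    ):
        # Window closed: fold the running window mean into the
        # epoch-level meter, as avg_loss_val.update(cur_loss_val.avg) does.
        avg_loss_val.update(cur_loss_val.avg)

print(avg_loss_val.avg)   # epoch-level "val/loss" analogue
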
