diff options
Diffstat (limited to 'training/functional.py')
-rw-r--r-- | training/functional.py | 32 |
1 files changed, 16 insertions, 16 deletions
diff --git a/training/functional.py b/training/functional.py index 3036ed9..6ae35a0 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -468,7 +468,8 @@ def train_loop( | |||
468 | callbacks: TrainingCallbacks = TrainingCallbacks(), | 468 | callbacks: TrainingCallbacks = TrainingCallbacks(), |
469 | ): | 469 | ): |
470 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) | 470 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) |
471 | num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0 | 471 | num_val_steps_per_epoch = math.ceil( |
472 | len(val_dataloader) / gradient_accumulation_steps) if val_dataloader is not None else 0 | ||
472 | 473 | ||
473 | num_training_steps = num_training_steps_per_epoch * num_epochs | 474 | num_training_steps = num_training_steps_per_epoch * num_epochs |
474 | num_val_steps = num_val_steps_per_epoch * num_epochs | 475 | num_val_steps = num_val_steps_per_epoch * num_epochs |
@@ -476,8 +477,8 @@ def train_loop( | |||
476 | global_step = 0 | 477 | global_step = 0 |
477 | cache = {} | 478 | cache = {} |
478 | 479 | ||
479 | best_acc = avg_acc.avg | 480 | best_acc = avg_acc.max |
480 | best_acc_val = avg_acc_val.avg | 481 | best_acc_val = avg_acc_val.max |
481 | 482 | ||
482 | local_progress_bar = tqdm( | 483 | local_progress_bar = tqdm( |
483 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), | 484 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), |
@@ -591,35 +592,34 @@ def train_loop( | |||
591 | on_after_epoch() | 592 | on_after_epoch() |
592 | 593 | ||
593 | if val_dataloader is not None: | 594 | if val_dataloader is not None: |
594 | cur_loss_val = AverageMeter() | 595 | cur_loss_val = AverageMeter(power=1) |
595 | cur_acc_val = AverageMeter() | 596 | cur_acc_val = AverageMeter(power=1) |
596 | 597 | ||
597 | with torch.inference_mode(), on_eval(): | 598 | with torch.inference_mode(), on_eval(): |
598 | for step, batch in enumerate(val_dataloader): | 599 | for step, batch in enumerate(val_dataloader): |
599 | loss, acc, bsz = loss_step(step, batch, cache, True) | 600 | loss, acc, bsz = loss_step(step, batch, cache, True) |
600 | 601 | loss /= gradient_accumulation_steps | |
601 | loss = loss.detach_() | ||
602 | acc = acc.detach_() | ||
603 | 602 | ||
604 | cur_loss_val.update(loss.item(), bsz) | 603 | cur_loss_val.update(loss.item(), bsz) |
605 | cur_acc_val.update(acc.item(), bsz) | 604 | cur_acc_val.update(acc.item(), bsz) |
606 | 605 | ||
607 | avg_loss_val.update(loss.item(), bsz) | ||
608 | avg_acc_val.update(acc.item(), bsz) | ||
609 | |||
610 | local_progress_bar.update(1) | ||
611 | global_progress_bar.update(1) | ||
612 | |||
613 | logs = { | 606 | logs = { |
614 | "val/loss": avg_loss_val.avg, | ||
615 | "val/acc": avg_acc_val.avg, | ||
616 | "val/cur_loss": loss.item(), | 607 | "val/cur_loss": loss.item(), |
617 | "val/cur_acc": acc.item(), | 608 | "val/cur_acc": acc.item(), |
618 | } | 609 | } |
619 | local_progress_bar.set_postfix(**logs) | 610 | local_progress_bar.set_postfix(**logs) |
620 | 611 | ||
612 | if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)): | ||
613 | local_progress_bar.update(1) | ||
614 | global_progress_bar.update(1) | ||
615 | |||
616 | avg_loss_val.update(cur_loss_val.avg) | ||
617 | avg_acc_val.update(cur_acc_val.avg) | ||
618 | |||
621 | logs["val/cur_loss"] = cur_loss_val.avg | 619 | logs["val/cur_loss"] = cur_loss_val.avg |
622 | logs["val/cur_acc"] = cur_acc_val.avg | 620 | logs["val/cur_acc"] = cur_acc_val.avg |
621 | logs["val/loss"] = avg_loss_val.avg | ||
622 | logs["val/acc"] = avg_acc_val.avg | ||
623 | 623 | ||
624 | accelerator.log(logs, step=global_step) | 624 | accelerator.log(logs, step=global_step) |
625 | 625 | ||