From 3575d041f1507811b577fd2c653171fb51c0a386 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 20 Jan 2023 14:26:17 +0100 Subject: Restored LR finder --- training/functional.py | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) (limited to 'training/functional.py') diff --git a/training/functional.py b/training/functional.py index fb135c4..c373ac9 100644 --- a/training/functional.py +++ b/training/functional.py @@ -7,7 +7,6 @@ from pathlib import Path import itertools import torch -import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader @@ -373,8 +372,12 @@ def train_loop( avg_loss_val = AverageMeter() avg_acc_val = AverageMeter() - max_acc = 0.0 - max_acc_val = 0.0 + best_acc = 0.0 + best_acc_val = 0.0 + + lrs = [] + losses = [] + accs = [] local_progress_bar = tqdm( range(num_training_steps_per_epoch + num_val_steps_per_epoch), @@ -457,6 +460,8 @@ def train_loop( accelerator.wait_for_everyone() + lrs.append(lr_scheduler.get_last_lr()[0]) + on_after_epoch(lr_scheduler.get_last_lr()[0]) if val_dataloader is not None: @@ -498,18 +503,24 @@ def train_loop( global_progress_bar.clear() if accelerator.is_main_process: - if avg_acc_val.avg.item() > max_acc_val: + if avg_acc_val.avg.item() > best_acc_val: accelerator.print( - f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") + f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") on_checkpoint(global_step + global_step_offset, "milestone") - max_acc_val = avg_acc_val.avg.item() + best_acc_val = avg_acc_val.avg.item() + + losses.append(avg_loss_val.avg.item()) + accs.append(avg_acc_val.avg.item()) else: if accelerator.is_main_process: - if avg_acc.avg.item() > max_acc: + if avg_acc.avg.item() > best_acc: accelerator.print( - f"Global step {global_step}: Training accuracy reached new maximum: {max_acc:.2e} -> {avg_acc.avg.item():.2e}") + f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}") on_checkpoint(global_step + global_step_offset, "milestone") - max_acc = avg_acc.avg.item() + best_acc = avg_acc.avg.item() + + losses.append(avg_loss.avg.item()) + accs.append(avg_acc.avg.item()) # Create the pipeline using using the trained modules and save it. if accelerator.is_main_process: @@ -523,6 +534,8 @@ def train_loop( on_checkpoint(global_step + global_step_offset, "end") raise KeyboardInterrupt + return lrs, losses, accs + def train( accelerator: Accelerator, @@ -582,7 +595,7 @@ def train( if accelerator.is_main_process: accelerator.init_trackers(project) - train_loop( + metrics = train_loop( accelerator=accelerator, optimizer=optimizer, lr_scheduler=lr_scheduler, @@ -598,3 +611,5 @@ def train( accelerator.end_training() accelerator.free_memory() + + return metrics -- cgit v1.2.3-54-g00ecf