From 30098b1d611853c0d3a4687d84582e1c1cf1b938 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 27 Dec 2022 11:48:33 +0100
Subject: Added validation phase to learn rate finder

---
 training/lr.py | 34 +++++++++++++++++++++++-----------
 1 file changed, 23 insertions(+), 11 deletions(-)

diff --git a/training/lr.py b/training/lr.py
index dd37baa..5343f24 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,20 +1,22 @@
+import matplotlib.pyplot as plt
 import numpy as np
+import torch
 from torch.optim.lr_scheduler import LambdaLR
 from tqdm.auto import tqdm
-import matplotlib.pyplot as plt
 
 from training.util import AverageMeter
 
 
 class LRFinder():
-    def __init__(self, accelerator, model, optimizer, train_dataloader, loss_fn):
+    def __init__(self, accelerator, model, optimizer, train_dataloader, val_dataloader, loss_fn):
         self.accelerator = accelerator
         self.model = model
         self.optimizer = optimizer
         self.train_dataloader = train_dataloader
+        self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
 
-    def run(self, num_epochs=100, num_steps=1, smooth_f=0.05, diverge_th=5):
+    def run(self, num_epochs=100, num_train_steps=1, num_val_steps=1, smooth_f=0.05, diverge_th=5):
         best_loss = None
         lrs = []
         losses = []
@@ -22,7 +24,7 @@ class LRFinder():
         lr_scheduler = get_exponential_schedule(self.optimizer, num_epochs)
 
         progress_bar = tqdm(
-            range(num_epochs * num_steps),
+            range(num_epochs * (num_train_steps + num_val_steps)),
             disable=not self.accelerator.is_local_main_process,
             dynamic_ncols=True
         )
@@ -33,6 +35,8 @@ class LRFinder():
 
             avg_loss = AverageMeter()
 
+            self.model.train()
+
             for step, batch in enumerate(self.train_dataloader):
                 with self.accelerator.accumulate(self.model):
                     loss, acc, bsz = self.loss_fn(batch)
@@ -42,13 +46,24 @@ class LRFinder():
                     self.optimizer.step()
                     self.optimizer.zero_grad(set_to_none=True)
 
-                    avg_loss.update(loss.detach_(), bsz)
+                if self.accelerator.sync_gradients:
+                    progress_bar.update(1)
 
-                if step >= num_steps:
+                if step >= num_train_steps:
                     break
 
-                if self.accelerator.sync_gradients:
-                    progress_bar.update(1)
+            self.model.eval()
+
+            with torch.inference_mode():
+                for step, batch in enumerate(self.val_dataloader):
+                    loss, acc, bsz = self.loss_fn(batch)
+                    avg_loss.update(loss.detach_(), bsz)
+
+                    if self.accelerator.sync_gradients:
+                        progress_bar.update(1)
+
+                    if step >= num_val_steps:
+                        break
 
             lr_scheduler.step()
 
@@ -104,9 +119,6 @@ class LRFinder():
         ax.set_xlabel("Learning rate")
         ax.set_ylabel("Loss")
 
-        if fig is not None:
-            plt.show()
-
 
 def get_exponential_schedule(optimizer, num_epochs, last_epoch=-1):
     def lr_lambda(current_epoch: int):
--
cgit v1.2.3-70-g09d2
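Usage note (not part of the patch): after this change, LRFinder takes a separate validation dataloader, run() splits the old num_steps argument into per-epoch num_train_steps and num_val_steps, and the smoothed loss curve is fed from the validation batches rather than the training batches. A minimal wiring sketch follows; accelerator, model, optimizer, the two dataloaders and loss_fn are placeholders for objects prepared elsewhere, not identifiers taken from this repository, and loss_fn is assumed to return (loss, acc, bsz) for a batch as the class expects:

    # Hypothetical call site: every bare name below is a placeholder.
    lr_finder = LRFinder(
        accelerator=accelerator,
        model=model,
        optimizer=optimizer,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,  # new argument introduced by this patch
        loss_fn=loss_fn,                # must return (loss, acc, bsz) per batch
    )
    # num_steps is gone; train and validation steps per epoch are set separately.
    lr_finder.run(num_epochs=100, num_train_steps=1, num_val_steps=1)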