From 3575d041f1507811b577fd2c653171fb51c0a386 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 20 Jan 2023 14:26:17 +0100
Subject: Restored LR finder

---
 training/lr.py | 266 +++++++-------------------------------------------------
 1 file changed, 32 insertions(+), 234 deletions(-)

(limited to 'training/lr.py')

diff --git a/training/lr.py b/training/lr.py
index 9690738..f5b362f 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,238 +1,36 @@
-import math
-from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union
-from functools import partial
+from pathlib import Path
 
 import matplotlib.pyplot as plt
-import numpy as np
-import torch
-from torch.optim.lr_scheduler import LambdaLR
-from tqdm.auto import tqdm
-
-from training.functional import TrainingCallbacks
-from training.util import AverageMeter
 
 
-def noop(*args, **kwards):
-    pass
-
-
-def noop_ctx(*args, **kwards):
-    return nullcontext()
-
-
-class LRFinder():
-    def __init__(
-        self,
-        accelerator,
-        optimizer,
-        train_dataloader,
-        val_dataloader,
-        loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
-        callbacks: TrainingCallbacks = TrainingCallbacks()
-    ):
-        self.accelerator = accelerator
-        self.model = callbacks.on_model()
-        self.optimizer = optimizer
-        self.train_dataloader = train_dataloader
-        self.val_dataloader = val_dataloader
-        self.loss_fn = loss_fn
-        self.callbacks = callbacks
-
-        # self.model_state = copy.deepcopy(model.state_dict())
-        # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
-
-    def run(
-        self,
-        end_lr,
-        skip_start: int = 10,
-        skip_end: int = 5,
-        num_epochs: int = 100,
-        num_train_batches: int = math.inf,
-        num_val_batches: int = math.inf,
-        smooth_f: float = 0.05,
-    ):
-        best_loss = None
-        best_acc = None
-
-        lrs = []
-        losses = []
-        accs = []
-
-        lr_scheduler = get_exponential_schedule(
-            self.optimizer,
-            end_lr,
-            num_epochs * min(num_train_batches, len(self.train_dataloader))
-        )
-
-        steps = min(num_train_batches, len(self.train_dataloader))
-        steps += min(num_val_batches, len(self.val_dataloader))
-        steps *= num_epochs
-
-        progress_bar = tqdm(
-            range(steps),
-            disable=not self.accelerator.is_local_main_process,
-            dynamic_ncols=True
-        )
-        progress_bar.set_description("Epoch X / Y")
-
-        self.callbacks.on_prepare()
-
-        on_train = self.callbacks.on_train
-        on_before_optimize = self.callbacks.on_before_optimize
-        on_after_optimize = self.callbacks.on_after_optimize
-        on_eval = self.callbacks.on_eval
-
-        for epoch in range(num_epochs):
-            progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
-
-            avg_loss = AverageMeter()
-            avg_acc = AverageMeter()
-
-            self.model.train()
-
-            with on_train(epoch):
-                for step, batch in enumerate(self.train_dataloader):
-                    if step >= num_train_batches:
-                        break
-
-                    with self.accelerator.accumulate(self.model):
-                        loss, acc, bsz = self.loss_fn(step, batch)
-
-                        self.accelerator.backward(loss)
-
-                        on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
-
-                        self.optimizer.step()
-                        lr_scheduler.step()
-                        self.optimizer.zero_grad(set_to_none=True)
-
-                    if self.accelerator.sync_gradients:
-                        on_after_optimize(lr_scheduler.get_last_lr()[0])
-
-                    progress_bar.update(1)
-
-            self.model.eval()
-
-            with torch.inference_mode():
-                with on_eval():
-                    for step, batch in enumerate(self.val_dataloader):
-                        if step >= num_val_batches:
-                            break
-
-                        loss, acc, bsz = self.loss_fn(step, batch, True)
-                        avg_loss.update(loss.detach_(), bsz)
-                        avg_acc.update(acc.detach_(), bsz)
-
-                        progress_bar.update(1)
-
-            loss = avg_loss.avg.item()
-            acc = avg_acc.avg.item()
-
-            if epoch == 0:
-                best_loss = loss
-                best_acc = acc
-            else:
-                if smooth_f > 0:
-                    loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
-                    acc = smooth_f * acc + (1 - smooth_f) * accs[-1]
-                if loss < best_loss:
-                    best_loss = loss
-                if acc > best_acc:
-                    best_acc = acc
-
-            lr = lr_scheduler.get_last_lr()[0]
-
-            lrs.append(lr)
-            losses.append(loss)
-            accs.append(acc)
-
-            self.accelerator.log({
-                "loss": loss,
-                "acc": acc,
-                "lr": lr,
-            }, step=epoch)
-
-            progress_bar.set_postfix({
-                "loss": loss,
-                "loss/best": best_loss,
-                "acc": acc,
-                "acc/best": best_acc,
-                "lr": lr,
-            })
-
-        # self.model.load_state_dict(self.model_state)
-        # self.optimizer.load_state_dict(self.optimizer_state)
-
-        if skip_end == 0:
-            lrs = lrs[skip_start:]
-            losses = losses[skip_start:]
-            accs = accs[skip_start:]
-        else:
-            lrs = lrs[skip_start:-skip_end]
-            losses = losses[skip_start:-skip_end]
-            accs = accs[skip_start:-skip_end]
-
-        fig, ax_loss = plt.subplots()
-        ax_acc = ax_loss.twinx()
-
-        ax_loss.plot(lrs, losses, color='red')
-        ax_loss.set_xscale("log")
-        ax_loss.set_xlabel(f"Learning rate")
-        ax_loss.set_ylabel("Loss")
-
-        ax_acc.plot(lrs, accs, color='blue')
-        ax_acc.set_xscale("log")
-        ax_acc.set_ylabel("Accuracy")
-
-        print("LR suggestion: steepest gradient")
-        min_grad_idx = None
-
-        try:
-            min_grad_idx = np.gradient(np.array(losses)).argmin()
-        except ValueError:
-            print(
-                "Failed to compute the gradients, there might not be enough points."
-            )
-
-        try:
-            max_val_idx = np.array(accs).argmax()
-        except ValueError:
-            print(
-                "Failed to compute the gradients, there might not be enough points."
-            )
-
-        if min_grad_idx is not None:
-            print("Suggested LR (loss): {:.2E}".format(lrs[min_grad_idx]))
-            ax_loss.scatter(
-                lrs[min_grad_idx],
-                losses[min_grad_idx],
-                s=75,
-                marker="o",
-                color="red",
-                zorder=3,
-                label="steepest gradient",
-            )
-            ax_loss.legend()
-
-        if max_val_idx is not None:
-            print("Suggested LR (acc): {:.2E}".format(lrs[max_val_idx]))
-            ax_acc.scatter(
-                lrs[max_val_idx],
-                accs[max_val_idx],
-                s=75,
-                marker="o",
-                color="blue",
-                zorder=3,
-                label="maximum",
-            )
-            ax_acc.legend()
-
-
-def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
-    def lr_lambda(base_lr: float, current_epoch: int):
-        return (end_lr / base_lr) ** (current_epoch / num_epochs)
-
-    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]
-
-    return LambdaLR(optimizer, lr_lambdas, last_epoch)
+def plot_metrics(
+    metrics: tuple[list[float], list[float], list[float]],
+    output_file: Path,
+    skip_start: int = 10,
+    skip_end: int = 5,
+):
+    lrs, losses, accs = metrics
+
+    if skip_end == 0:
+        lrs = lrs[skip_start:]
+        losses = losses[skip_start:]
+        accs = accs[skip_start:]
+    else:
+        lrs = lrs[skip_start:-skip_end]
+        losses = losses[skip_start:-skip_end]
+        accs = accs[skip_start:-skip_end]
+
+    fig, ax_loss = plt.subplots()
+    ax_acc = ax_loss.twinx()
+
+    ax_loss.plot(lrs, losses, color='red')
+    ax_loss.set_xscale("log")
+    ax_loss.set_xlabel(f"Learning rate")
+    ax_loss.set_ylabel("Loss")
+
+    ax_acc.plot(lrs, accs, color='blue')
+    ax_acc.set_xscale("log")
+    ax_acc.set_ylabel("Accuracy")
+
+    plt.savefig(output_file, dpi=300)
+    plt.close()
-- 
cgit v1.2.3-54-g00ecf
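
For reference, a minimal usage sketch of the plot_metrics helper added by this patch (the sketch itself is not part of the commit). It assumes the caller has already collected learning rates, losses, and accuracies into the (lrs, losses, accs) tuple the function expects; the metric values and the output filename below are made-up placeholders.

    from pathlib import Path

    from training.lr import plot_metrics

    # Placeholder metrics standing in for values gathered during an LR sweep:
    # exponentially increasing learning rates, a loss curve, an accuracy curve.
    lrs = [1e-6 * 10 ** (i / 10) for i in range(40)]
    losses = [1.0 / (1 + i) + (0.05 * (i - 30) if i > 30 else 0.0) for i in range(40)]
    accs = [min(0.05 + 0.02 * i, 0.95) for i in range(40)]

    # Drop the first 10 and last 5 points (the function's defaults) and write the
    # twin-axis loss/accuracy-vs-learning-rate plot to disk.
    plot_metrics((lrs, losses, accs), Path("lr_finder.png"), skip_start=10, skip_end=5)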