From 46d631759f59bc6b65458202641e5f5a9bc30b7b Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 2 Jan 2023 20:13:59 +0100
Subject: Fixed LR finder

---
 training/lr.py | 38 +++++++++++++++++++++-----------------
 1 file changed, 21 insertions(+), 17 deletions(-)

diff --git a/training/lr.py b/training/lr.py
index fe166ed..acc01a2 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,6 +1,7 @@
 import math
 import copy
 from typing import Callable
+from functools import partial
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -41,7 +42,7 @@ class LRFinder():
 
     def run(
         self,
-        min_lr,
+        end_lr,
         skip_start: int = 10,
         skip_end: int = 5,
         num_epochs: int = 100,
@@ -57,7 +58,7 @@ class LRFinder():
         losses = []
         accs = []
 
-        lr_scheduler = get_exponential_schedule(self.optimizer, min_lr, num_epochs)
+        lr_scheduler = get_exponential_schedule(self.optimizer, end_lr, num_epochs)
 
         steps = min(num_train_batches, len(self.train_dataloader))
         steps += min(num_val_batches, len(self.val_dataloader))
@@ -152,29 +153,30 @@ class LRFinder():
                 print("Stopping early, the loss has diverged")
                 break
 
-        if skip_end == 0:
-            lrs = lrs[skip_start:]
-            losses = losses[skip_start:]
-            accs = accs[skip_start:]
-        else:
-            lrs = lrs[skip_start:-skip_end]
-            losses = losses[skip_start:-skip_end]
-            accs = accs[skip_start:-skip_end]
-
         fig, ax_loss = plt.subplots()
+        ax_acc = ax_loss.twinx()
 
         ax_loss.plot(lrs, losses, color='red')
         ax_loss.set_xscale("log")
-        ax_loss.set_xlabel(f"Learning rate")
+        ax_loss.set_xlabel("Learning rate")
         ax_loss.set_ylabel("Loss")
 
-        ax_acc = ax_loss.twinx()
         ax_acc.plot(lrs, accs, color='blue')
+        ax_acc.set_xscale("log")
         ax_acc.set_ylabel("Accuracy")
 
         print("LR suggestion: steepest gradient")
         min_grad_idx = None
 
+        if skip_end == 0:
+            lrs = lrs[skip_start:]
+            losses = losses[skip_start:]
+            accs = accs[skip_start:]
+        else:
+            lrs = lrs[skip_start:-skip_end]
+            losses = losses[skip_start:-skip_end]
+            accs = accs[skip_start:-skip_end]
+
         try:
             min_grad_idx = (np.gradient(np.array(losses))).argmin()
         except ValueError:
@@ -196,8 +198,10 @@ class LRFinder():
         ax_loss.legend()
 
 
-def get_exponential_schedule(optimizer, min_lr, num_epochs, last_epoch=-1):
-    def lr_lambda(current_epoch: int):
-        return min_lr + ((current_epoch / num_epochs) ** 10) * (1 - min_lr)
+def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
+    def lr_lambda(base_lr: float, current_epoch: int):
+        return (end_lr / base_lr) ** (current_epoch / num_epochs)
+
+    lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]
 
-    return LambdaLR(optimizer, lr_lambda, last_epoch)
+    return LambdaLR(optimizer, lr_lambdas, last_epoch)
-- 
cgit v1.2.3-70-g09d2
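
Note on the rewritten schedule: LambdaLR multiplies each parameter group's base LR by its lambda, so with lr_lambda(base_lr, t) = (end_lr / base_lr) ** (t / num_epochs) the finder sweeps geometrically from the optimizer's initial LR at epoch 0 to end_lr at epoch num_epochs. The old degree-10 polynomial ramp kept the multiplier near its starting value for most of the run; the exponential form advances the LR evenly in log space, which is what the log-scaled plot assumes. Below is a minimal standalone sketch of that behavior; the toy model, the initial LR of 1e-6, end_lr=1e-1, and the 100-epoch sweep are illustrative assumptions, not values from the commit:

    from functools import partial

    import torch
    from torch.optim.lr_scheduler import LambdaLR


    def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
        # Same scheme as the patch: lr(t) = base_lr * (end_lr / base_lr) ** (t / num_epochs),
        # so lr(0) == base_lr and lr(num_epochs) == end_lr.
        def lr_lambda(base_lr: float, current_epoch: int):
            return (end_lr / base_lr) ** (current_epoch / num_epochs)

        lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]

        return LambdaLR(optimizer, lr_lambdas, last_epoch)


    model = torch.nn.Linear(4, 4)                             # stand-in model (assumption)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)  # initial LR = start of the sweep
    scheduler = get_exponential_schedule(optimizer, end_lr=1e-1, num_epochs=100)

    for epoch in range(101):
        if epoch % 25 == 0:
            # prints roughly 1e-06, 1.8e-05, 3.2e-04, 5.6e-03, 1e-01
            print(epoch, scheduler.get_last_lr())
        optimizer.step()    # optimizer first, then scheduler, to avoid PyTorch's ordering warning
        scheduler.step()

Binding one partial per param group is what lets groups with different base LRs all land on the same end_lr at the final epoch, which is presumably why the single shared lr_lambda was replaced with a list of lambdas.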