From 181d56a0af567309a6fda4bfc4e2243ad5f4ca06 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 5 Jan 2023 18:55:41 +0100
Subject: Fix LR finder

---
 training/lr.py | 30 +++++++++++++++++++++++-------
 1 file changed, 23 insertions(+), 7 deletions(-)

diff --git a/training/lr.py b/training/lr.py
index 3cdf994..c765150 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -51,7 +51,7 @@ class LRFinder():
         num_train_batches: int = 1,
         num_val_batches: int = math.inf,
         smooth_f: float = 0.05,
-        diverge_th: int = 5
+        diverge_th: int = 5,
     ):
         best_loss = None
         best_acc = None
@@ -157,10 +157,6 @@ class LRFinder():
             #     self.model.load_state_dict(self.model_state)
             #     self.optimizer.load_state_dict(self.optimizer_state)
 
-            if loss > diverge_th * best_loss:
-                print("Stopping early, the loss has diverged")
-                break
-
         fig, ax_loss = plt.subplots()
         ax_acc = ax_loss.twinx()
 
@@ -186,14 +182,21 @@
         accs = accs[skip_start:-skip_end]
 
         try:
-            min_grad_idx = (np.gradient(np.array(losses))).argmin()
+            min_grad_idx = np.gradient(np.array(losses)).argmin()
+        except ValueError:
+            print(
+                "Failed to compute the gradients, there might not be enough points."
+            )
+
+        try:
+            max_val_idx = np.array(accs).argmax()
         except ValueError:
             print(
                 "Failed to compute the gradients, there might not be enough points."
             )
 
         if min_grad_idx is not None:
-            print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
+            print("Suggested LR (loss): {:.2E}".format(lrs[min_grad_idx]))
             ax_loss.scatter(
                 lrs[min_grad_idx],
                 losses[min_grad_idx],
@@ -205,6 +208,19 @@
             )
             ax_loss.legend()
 
+        if max_val_idx is not None:
+            print("Suggested LR (acc): {:.2E}".format(lrs[max_val_idx]))
+            ax_acc.scatter(
+                lrs[max_val_idx],
+                accs[max_val_idx],
+                s=75,
+                marker="o",
+                color="blue",
+                zorder=3,
+                label="maximum",
+            )
+            ax_acc.legend()
+
 
 def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
     def lr_lambda(base_lr: float, current_epoch: int):
-- 
cgit v1.2.3-70-g09d2
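
For context, the two suggestion rules this patch settles on can be exercised in
isolation. The sketch below is not part of the commit: the learning-rate sweep
data is synthetic and invented purely for illustration, and only the
np.gradient(...).argmin() / .argmax() selection mirrors the patched code.

import numpy as np

# Synthetic sweep, invented purely for illustration: the loss falls off
# sigmoidally (steepest drop mid-sweep) and accuracy peaks mid-sweep.
lrs = np.logspace(-6, -1, num=50)
x = np.linspace(-6, 6, num=50)
losses = 1.0 / (1.0 + np.exp(x))   # decreasing, steepest around x = 0
accs = np.exp(-x**2 / 4)           # bell curve, maximum around x = 0

# The two selection rules the patch applies to the recorded curves:
min_grad_idx = np.gradient(np.array(losses)).argmin()  # steepest loss descent
max_val_idx = np.array(accs).argmax()                  # accuracy maximum

print("Suggested LR (loss): {:.2E}".format(lrs[min_grad_idx]))
print("Suggested LR (acc): {:.2E}".format(lrs[max_val_idx]))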