From 33e7d2ed37e32657ca94d92815043026c4cea7c0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 10 Jan 2023 09:22:02 +0100 Subject: Added arg to disable tag shuffling --- training/lr.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) (limited to 'training') diff --git a/training/lr.py b/training/lr.py index 68e0f72..dfb1743 100644 --- a/training/lr.py +++ b/training/lr.py @@ -48,7 +48,7 @@ class LRFinder(): skip_start: int = 10, skip_end: int = 5, num_epochs: int = 100, - num_train_batches: int = 1, + num_train_batches: int = math.inf, num_val_batches: int = math.inf, smooth_f: float = 0.05, ): @@ -156,6 +156,15 @@ class LRFinder(): # self.model.load_state_dict(self.model_state) # self.optimizer.load_state_dict(self.optimizer_state) + if skip_end == 0: + lrs = lrs[skip_start:] + losses = losses[skip_start:] + accs = accs[skip_start:] + else: + lrs = lrs[skip_start:-skip_end] + losses = losses[skip_start:-skip_end] + accs = accs[skip_start:-skip_end] + fig, ax_loss = plt.subplots() ax_acc = ax_loss.twinx() @@ -171,15 +180,6 @@ class LRFinder(): print("LR suggestion: steepest gradient") min_grad_idx = None - if skip_end == 0: - lrs = lrs[skip_start:] - losses = losses[skip_start:] - accs = accs[skip_start:] - else: - lrs = lrs[skip_start:-skip_end] - losses = losses[skip_start:-skip_end] - accs = accs[skip_start:-skip_end] - try: min_grad_idx = np.gradient(np.array(losses)).argmin() except ValueError: -- cgit v1.2.3-70-g09d2