From 54d72ba4a8331d822a48bad9e381b47d39598125 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 28 Dec 2022 21:00:34 +0100
Subject: Updated 1-cycle scheduler

---
 training/lr.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/training/lr.py b/training/lr.py
index c1fa3a0..c0e9b3f 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -19,8 +19,8 @@ class LRFinder():
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
 
-        self.model_state = copy.deepcopy(model.state_dict())
-        self.optimizer_state = copy.deepcopy(optimizer.state_dict())
+        # self.model_state = copy.deepcopy(model.state_dict())
+        # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
 
     def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5):
         best_loss = None
@@ -109,8 +109,8 @@ class LRFinder():
                 "lr": lr,
             })
 
-            self.model.load_state_dict(self.model_state)
-            self.optimizer.load_state_dict(self.optimizer_state)
+            # self.model.load_state_dict(self.model_state)
+            # self.optimizer.load_state_dict(self.optimizer_state)
 
             if loss > diverge_th * best_loss:
                 print("Stopping early, the loss has diverged")
@@ -127,12 +127,14 @@ class LRFinder():
 
         fig, ax_loss = plt.subplots()
 
-        ax_loss.plot(lrs, losses, color='red', label='Loss')
+        ax_loss.plot(lrs, losses, color='red')
         ax_loss.set_xscale("log")
         ax_loss.set_xlabel("Learning rate")
+        ax_loss.set_ylabel("Loss")
 
         # ax_acc = ax_loss.twinx()
-        # ax_acc.plot(lrs, accs, color='blue', label='Accuracy')
+        # ax_acc.plot(lrs, accs, color='blue')
+        # ax_acc.set_ylabel("Accuracy")
 
         print("LR suggestion: steepest gradient")
         min_grad_idx = None
--
cgit v1.2.3-54-g00ecf
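
For context, a minimal self-contained sketch of the plotting behavior after this patch: the loss curve no longer carries a legend label, and each axis gets an explicit label via `set_ylabel` instead. The `lrs`/`losses` arrays below are made-up stand-ins for the values `LRFinder` records during its sweep, and the `np.gradient`-based "steepest gradient" pick is an assumption about code outside the shown hunks (the diff only shows the `print` and `min_grad_idx = None` context lines), not the repository's actual implementation.

```python
# Hypothetical sketch, not part of the patch: toy data standing in for
# the learning rates and losses that LRFinder accumulates.
import math

import matplotlib.pyplot as plt
import numpy as np

lrs = [10 ** (i / 10) for i in range(-60, -9)]  # 1e-6 .. 1e-1, log-spaced
losses = [1.0 / (1.0 + 100.0 * lr) + math.exp(150.0 * (lr - 0.08))
          for lr in lrs]  # toy curve: falls, then diverges

fig, ax_loss = plt.subplots()

# After the patch: no legend label; the y-axis is labeled directly instead.
ax_loss.plot(lrs, losses, color='red')
ax_loss.set_xscale("log")
ax_loss.set_xlabel("Learning rate")
ax_loss.set_ylabel("Loss")

# Assumed "steepest gradient" suggestion: the LR where loss drops fastest.
min_grad_idx = int(np.argmin(np.gradient(np.array(losses))))
ax_loss.scatter(lrs[min_grad_idx], losses[min_grad_idx], color='blue')
print(f"LR suggestion: {lrs[min_grad_idx]:.2E}")

plt.show()
```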