diff options
Diffstat (limited to 'training/lr.py')
| -rw-r--r-- | training/lr.py | 14 |
1 file changed, 8 insertions, 6 deletions
diff --git a/training/lr.py b/training/lr.py
index c1fa3a0..c0e9b3f 100644
--- a/training/lr.py
+++ b/training/lr.py
| @@ -19,8 +19,8 @@ class LRFinder(): | |||
| 19 | self.val_dataloader = val_dataloader | 19 | self.val_dataloader = val_dataloader |
| 20 | self.loss_fn = loss_fn | 20 | self.loss_fn = loss_fn |
| 21 | 21 | ||
| 22 | self.model_state = copy.deepcopy(model.state_dict()) | 22 | # self.model_state = copy.deepcopy(model.state_dict()) |
| 23 | self.optimizer_state = copy.deepcopy(optimizer.state_dict()) | 23 | # self.optimizer_state = copy.deepcopy(optimizer.state_dict()) |
| 24 | 24 | ||
| 25 | def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5): | 25 | def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5): |
| 26 | best_loss = None | 26 | best_loss = None |
| @@ -109,8 +109,8 @@ class LRFinder(): | |||
| 109 | "lr": lr, | 109 | "lr": lr, |
| 110 | }) | 110 | }) |
| 111 | 111 | ||
| 112 | self.model.load_state_dict(self.model_state) | 112 | # self.model.load_state_dict(self.model_state) |
| 113 | self.optimizer.load_state_dict(self.optimizer_state) | 113 | # self.optimizer.load_state_dict(self.optimizer_state) |
| 114 | 114 | ||
| 115 | if loss > diverge_th * best_loss: | 115 | if loss > diverge_th * best_loss: |
| 116 | print("Stopping early, the loss has diverged") | 116 | print("Stopping early, the loss has diverged") |
| @@ -127,12 +127,14 @@ class LRFinder(): | |||
| 127 | 127 | ||
| 128 | fig, ax_loss = plt.subplots() | 128 | fig, ax_loss = plt.subplots() |
| 129 | 129 | ||
| 130 | ax_loss.plot(lrs, losses, color='red', label='Loss') | 130 | ax_loss.plot(lrs, losses, color='red') |
| 131 | ax_loss.set_xscale("log") | 131 | ax_loss.set_xscale("log") |
| 132 | ax_loss.set_xlabel("Learning rate") | 132 | ax_loss.set_xlabel("Learning rate") |
| 133 | ax_loss.set_ylabel("Loss") | ||
| 133 | 134 | ||
| 134 | # ax_acc = ax_loss.twinx() | 135 | # ax_acc = ax_loss.twinx() |
| 135 | # ax_acc.plot(lrs, accs, color='blue', label='Accuracy') | 136 | # ax_acc.plot(lrs, accs, color='blue') |
| 137 | # ax_acc.set_ylabel("Accuracy") | ||
| 136 | 138 | ||
| 137 | print("LR suggestion: steepest gradient") | 139 | print("LR suggestion: steepest gradient") |
| 138 | min_grad_idx = None | 140 | min_grad_idx = None |
